大家的AI
机器学习AI论文
加载中…

学习

🏅我的成就

Ch.04

注意力优化:FlashAttention 与稀疏注意力

稠密(Dense)铺满 N×NN\times NN×N ——表达强但一长就重;Flash 用分块减慢速显存往返;稀疏用局部+全局模式删边。

稠密 vs Flash vs 稀疏

稠密

Q×K
稠密: 所有 (i,j) 都打分 → 表达↑,大 NNN 时 成本↑。

Flash

Q×K
分块: 同一 softmax,分块执行 → HBM 往返↓,体感更快。

稀疏

Q×K
稀疏: 窗口 + 全局锚点 → 活跃对更少;远距靠设计补齐。

三种路径一览

① 找瓶颈(NNN / OOM)→ ② Flash 或稀疏 → ③ 指标验证
① Flash: IO 友好路径跑同一 softmax。
② 稀疏: 减边——长程写进掩码。
③ 稠密: N2N^2N2 放进预算。
④ 工程: OOM → Flash·batch·dtype;效果 → 全局·RAG。
自注意力让 token 以相似度互相融合,但长度 NNN 上来后,分数表约 N×NN \times NN×N,显存与算力会很快顶到 OOM 和高延迟。关键不是“注意力错了”,而是成本随长度涨得太陡。
FlashAttention在不改变 softmax 注意力目标的前提下,用分块和内存层级减少 HBM 往返,让同一套注意在 GPU 上跑得更快。稀疏注意力用局部窗口 + 少量全局锚点等固定模式限制每个 query 看的 key,从结构上少算、少占内存。你会分清:Flash ≈ 同一数学、更高效实现;稀疏 ≈ 改连接、要评测;以及二者如何落在训练与推理里。
核心公式(scaled dot-product / softmax 注意力)单头计算为
Attention(Q,K,V)=softmax(QKTdk)V\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V
先构造缩放后的 logits S=QKT/dkS=QK^T/\sqrt{d_k}S=QKT/dk​​。SijS_{ij}Sij​表示“第 iii 个位置应多关注第 jjj 个键”的原始得分;除以 dk\sqrt{d_k}dk​​ 是为了在头维度 dkd_kdk​ 变大时防止点积过大、softmax 过于尖锐。对每一行 iii 做 softmaxj\mathrm{softmax}_jsoftmaxj​ 得到权重 AijA_{ij}Aij​(∑jAij=1\sum_j A_{ij}=1∑j​Aij​=1),再用这些权重把 VVV 的各行加权求和得到输出的第 iii 行。
写成向量:QQQ 的第 iii 行 qiq_iqi​ 与 KKK 的第 jjj 行 kjk_jkj​ 满足 Sij=(qi⋅kj)/dkS_{ij}=(q_i\cdot k_j)/\sqrt{d_k}Sij​=(qi​⋅kj​)/dk​​,输出行 iii 为 ∑jAijVj\sum_j A_{ij}V_j∑j​Aij​Vj​。Flash 用更少内存搬运实现同一映射;稀疏在 softmax 前对 j∉Sij\notin S_ij∈/Si​ 置 −∞-\infty−∞ 掩码,使有效列集变小。

公式速览

序列长度为 NNN 时,Q,K,VQ,K,VQ,K,V 的每一行对应一个 token 的向量(常见设定下查询/键维度为 dkd_kdk​,值维度为 dvd_vdv​)。先构造缩放相似度(logits)矩阵
S=QKTdkS=\frac{QK^T}{\sqrt{d_k}}S=dk​​QKT​
SijS_{ij}Sij​衡量“查询 iii 与键 jjj 的匹配程度”。除以 dk\sqrt{d_k}dk​​ 是 scaled dot-product:维度升高时点积幅度变大,不做缩放 softmax 容易极端尖锐。
对每一查询行在键方向上 softmax,得到注意力权重 AAA:
A=softmax(S)=softmax(QKTdk),∑j=1NAij=1A=\mathrm{softmax}(S)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad \sum_{j=1}^{N} A_{ij}=1A=softmax(S)=softmax(dk​​QKT​),j=1∑N​Aij​=1
再用这些权重混合 VVV 的各行。第 iii 行输出是值向量的凸组合:
outputi=∑j=1NAijVj⇔Attention(Q,K,V)=AV\mathrm{output}_i=\sum_{j=1}^{N} A_{ij}V_j \quad\Leftrightarrow\quad \mathrm{Attention}(Q,K,V)=AVoutputi​=j=1∑N​Aij​Vj​⇔Attention(Q,K,V)=AV
合并两步即常用写法:
Attention(Q,K,V)=softmax(QKTdk)V\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V
主要瓶颈来自 QKTQK^TQKT(或 SSS)这张 N×NN\times NN×N 表。计算 QKTQK^TQKT 约 O(N2dk)O(N^2 d_k)O(N2dk​),元素个数约 N2N^2N2;NNN 翻一倍,格子约变四倍。随后的行 softmax 与 AVAVAV 仍与这一平方结构绑定。batch 与多头会再放大实际显存/FLOPs,但“长上下文为什么贵”的根因仍是这张二次网格。
稀疏注意力把“并非每个 jjj 都需要看”写进 softmax 前的 logits。令 E=QKT/dkE=QK^T/\sqrt{d_k}E=QKT/dk​​,对查询 iii 允许的键集合为 SiS_iSi​,在 j∉Sij\notin S_ij∈/Si​ 处设 Eij=−∞E_{ij}=-\inftyEij​=−∞(实现中常用足够小的有限负数)。
softmax(Ei:)j≈0(j∉Si)\mathrm{softmax}(E_{i:})_j \approx 0 \quad (j\notin S_i)softmax(Ei:​)j​≈0(j∈/Si​)
于是 AijA_{ij}Aij​ 只在 SiS_iSi​ 上实质非零,形式上仍是 output=AV\mathrm{output}=AVoutput=AV,但计算与内存围绕允许的位置对展开。
FlashAttention 仍以 softmax(QKT/dk)V\mathrm{softmax}(QK^T/\sqrt{d_k})Vsoftmax(QKT/dk​​)V 的注意力输出为目标,但避免整张 N×NN\times NN×N 长时间驻留慢速显存:分块计算,在片上快速内存中累积 softmax 的归一化统计并融合算子。映射不变,数据搬运更少。

注意力优化:FlashAttention 与稀疏注意力

1. 为何变重 概念: 注意力要一张约 N×NN \times NN×N 的分数表;长度稍涨,工作量常接近平方。类比: 握手。落地: 长 PDF、超长提示、批量推理最容易 OOM / 卡死。
记住: 核心是成本曲线。
2. FlashAttention 减少 HBM↔SRAM 往返,用 tile 在片上做完尽量多步骤。仍是同一套 softmax 注意,属于内核工程。
公式目标不变: softmax(QKTdk)V\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)Vsoftmax(dk​​QKT​)V ——变的是分块与融合,不是映射(在数值意义上)。
记住: 默认叙事是同映射、更快跑。
3. 稀疏注意力 只让 query 看窗内邻居 + 少量全局点等,∣Si∣≪N|S_i| \ll N∣Si​∣≪N。常见写法: 对 logits E=QKT/dkE=QK^T/\sqrt{d_k}E=QKT/dk​​,在 j∉Sij\notin S_ij∈/Si​ 处令 Eij=−∞E_{ij}=-\inftyEij​=−∞(softmax 前),使第 iii 行 softmax 只在 SiS_iSi​ 上有质量。
收益: 省算力、省显存。风险: 远距稀有关系可能被模式挡住。
记住: 本质是掩码/模式设计。
4. 区别 Flash→对齐 dense、加速。稀疏→改连接。实务: 先 Flash 度量;不够再 稀疏 + 评估。

为何重要

成本与规模 注意力常占 GPU 时间与显存峰值。Flash/融合核/更稀掩码让同一权重能服务更长上下文或更大 batch。长上下文 能吞整本书级提示的产品,背后往往是高效注意力,直接牵动体验与单价。精度与速度 每删一条边就省工作,也可能丢远处证据。基准测试远距任务,不够就加全局或 RAG。承接 Ch01–03 Q,K,VQ,K,VQ,K,V 与 softmax 思想不变;这里只谈怎么执行、谁能看谁。

如何用

训练 打开 Flash / SDPA / 融合注意,盯峰值显存与步耗时。同一卡上往往可以略增 batch/序列。推理 KV 缓存 + 高效核决定 TTFT 与解码;体感延迟最关键。模式 文档:局部+段落锚点;代码:括号/作用域可能需要更宽窗口。数据决定掩码。
OOM
①长·批·精度 →
②Flash真的开吗 →
③仍紧 再试 切块/稀疏/RAG。先弄清是否真需要全连接。

小结

一句话 — FlashAttention 让同一套 softmax 注意力在 GPU 上少搬内存、多分块,跑得更快;稀疏注意力用固定模式让每个 query 只看一部分 key,从根上少算、少占内存。
为什么一长就吃紧 — 长度 NNN 时分数表大约 N×NN \times NN×N。不是模型“坏了”,而是组合太多,显存和时间观感到瓶颈很正常。
Flash 用人话说 — 大块数据停在慢 HBM 里来回读写给不起,就把它拆成能在片上快速缓存里算完的小块。数学目标对齐标准注意,赚的是工程效率。
稀疏 用人话说 — 只让注意力看近邻 + 少量全局锚点之类的结构。省算力,但要留心:很远才出现的线索若不在模式里,效果可能掉;必要时配合 RAG 或调模式。
落地建议 — 先尽量用 Flash 扛住显存与延迟,再按需叠 稀疏/分段/检索。Ch01–03 的公式不变,本章教你执行层面的省与快。
核心公式(复习)— 分三步:logits S=QKT/dkS=QK^T/\sqrt{d_k}S=QKT/dk​​,按行 softmax 得 AAA,输出 AVAVAV。
Attention(Q,K,V)=softmax(QKTdk)V=AV\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V=AVAttention(Q,K,V)=softmax(dk​​QKT​)V=AV
位置 iii 的输出行是 ∑jAijVj\sum_j A_{ij}V_j∑j​Aij​Vj​,即用注意力权重混合 value 向量。QKTQK^TQKT 与 AAA 都涉及 N×NN\times NN×N,所以长度 NNN是成本核心。
本章要点:
• 痛点在 ~N2N^2N2 规模的分数工作。
• Flash=分块/内核,同一注意力。
• 稀疏=模式 + 验证。

解题提示

小结 — FlashAttention 用 IO 友好内核加速同一套 softmax 注意力;稀疏注意力用模式限制每个 query 看的 key,从而减少连接与计算。NNN 变大时 N2N^2N2 常成为瓶颈,实务中在 Flash 与稀疏模式之间权衡显存、延迟与效果。
速记: Flash=同样的注意,更快路径;稀疏=少连边,设计先行。OOM 先查 Flash 是否开启 与 长度/batch。
落地清单:
• 要保持同一数学只提速 → Flash 优先。
• 要减连接 → 稀疏 + 评测。
• 远距线索重要 → 改模式或加 RAG。
  • 题型概念选择
  • 填写内容
    ①
    ②
    ③ 序号 → 1–3
  • 题型对/错
  • 填写内容对/错 → 1/0
  • 题型情景
  • 填写内容最贴合项 → 1–3
  • 题型投票和
  • 填写内容01 向量中 1 的个数 → 整数
  • 题型求和
  • 填写内容数值总和 → 整数
  • 题型网格
  • 填写内容边长 nnn 的方格数 n2n^2n2 → 整数
  • 题型集成/权衡
  • 填写内容最贴切描述 → 1–3
题型填写内容
概念选择
①
②
③ 序号 → 1–3
对/错对/错 → 1/0
情景最贴合项 → 1–3
投票和01 向量中 1 的个数 → 整数
求和数值总和 → 整数
网格边长 nnn 的方格数 n2n^2n2 → 整数
集成/权衡最贴切描述 → 1–3
例(概念) “FlashAttention 的目标最接近?
①只做近似
②以 IO 友好方式加速 dense softmax 注意力
③去掉注意力” → 2

例(判断) “分块有助于利用片上内存。对填 1,错填 0” → 1

例(情景) “要求数学上等价输出:
①精确 Flash
②删所有 Q-K 边
③只升学习率” → 1

例(投票) [1,1,1,1,0] 中 1 的个数 → 4

例(求和) [7,8,7] → 22

例(网格) 边长 12 → 144144144 → 144

例(集成) “稀疏常见风险?
② 或损失少见的长距依赖” → 2