Ch.04
注意力优化:FlashAttention 与稀疏注意力
稠密(Dense)铺满 ——表达强但一长就重;Flash 用分块减慢速显存往返;稀疏用局部+全局模式删边。
稠密 vs Flash vs 稀疏
稠密
稠密: 所有 (i,j) 都打分 → 表达↑,大 时 成本↑。
Flash
分块: 同一 softmax,分块执行 → HBM 往返↓,体感更快。
稀疏
稀疏: 窗口 + 全局锚点 → 活跃对更少;远距靠设计补齐。
三种路径一览
① 找瓶颈( / OOM)→ ② Flash 或稀疏 → ③ 指标验证
① Flash: IO 友好路径跑同一 softmax。
② 稀疏: 减边——长程写进掩码。
③ 稠密: 放进预算。
④ 工程: OOM → Flash·batch·dtype;效果 → 全局·RAG。
自注意力让 token 以相似度互相融合,但长度 上来后,分数表约 ,显存与算力会很快顶到 OOM 和高延迟。关键不是“注意力错了”,而是成本随长度涨得太陡。
FlashAttention在不改变 softmax 注意力目标的前提下,用分块和内存层级减少 HBM 往返,让同一套注意在 GPU 上跑得更快。稀疏注意力用局部窗口 + 少量全局锚点等固定模式限制每个 query 看的 key,从结构上少算、少占内存。你会分清:Flash ≈ 同一数学、更高效实现;稀疏 ≈ 改连接、要评测;以及二者如何落在训练与推理里。
核心公式(scaled dot-product / softmax 注意力)单头计算为
先构造缩放后的 logits 。表示“第 个位置应多关注第 个键”的原始得分;除以 是为了在头维度 变大时防止点积过大、softmax 过于尖锐。对每一行 做 得到权重 (),再用这些权重把 的各行加权求和得到输出的第 行。
写成向量: 的第 行 与 的第 行 满足 ,输出行 为 。Flash 用更少内存搬运实现同一映射;稀疏在 softmax 前对 置 掩码,使有效列集变小。
公式速览
序列长度为 时, 的每一行对应一个 token 的向量(常见设定下查询/键维度为 ,值维度为 )。先构造缩放相似度(logits)矩阵
衡量“查询 与键 的匹配程度”。除以 是 scaled dot-product:维度升高时点积幅度变大,不做缩放 softmax 容易极端尖锐。
对每一查询行在键方向上 softmax,得到注意力权重 :
再用这些权重混合 的各行。第 行输出是值向量的凸组合:
合并两步即常用写法:
稀疏注意力把“并非每个 都需要看”写进 softmax 前的 logits。令 ,对查询 允许的键集合为 ,在 处设 (实现中常用足够小的有限负数)。
于是 只在 上实质非零,形式上仍是 ,但计算与内存围绕允许的位置对展开。
注意力优化:FlashAttention 与稀疏注意力
1. 为何变重 概念: 注意力要一张约 的分数表;长度稍涨,工作量常接近平方。类比: 握手。落地: 长 PDF、超长提示、批量推理最容易 OOM / 卡死。
记住: 核心是成本曲线。
2. FlashAttention 减少 HBM↔SRAM 往返,用 tile 在片上做完尽量多步骤。仍是同一套 softmax 注意,属于内核工程。
公式目标不变: ——变的是分块与融合,不是映射(在数值意义上)。
记住: 默认叙事是同映射、更快跑。
3. 稀疏注意力 只让 query 看窗内邻居 + 少量全局点等,。常见写法: 对 logits ,在 处令 (softmax 前),使第 行 softmax 只在 上有质量。
收益: 省算力、省显存。风险: 远距稀有关系可能被模式挡住。
记住: 本质是掩码/模式设计。
为何重要
成本与规模 注意力常占 GPU 时间与显存峰值。Flash/融合核/更稀掩码让同一权重能服务更长上下文或更大 batch。长上下文 能吞整本书级提示的产品,背后往往是高效注意力,直接牵动体验与单价。精度与速度 每删一条边就省工作,也可能丢远处证据。基准测试远距任务,不够就加全局或 RAG。承接 Ch01–03 与 softmax 思想不变;这里只谈怎么执行、谁能看谁。
如何用
训练 打开 Flash / SDPA / 融合注意,盯峰值显存与步耗时。同一卡上往往可以略增 batch/序列。推理 KV 缓存 + 高效核决定 TTFT 与解码;体感延迟最关键。模式 文档:局部+段落锚点;代码:括号/作用域可能需要更宽窗口。数据决定掩码。
OOM
①长·批·精度 →
②Flash真的开吗 →
③仍紧 再试 切块/稀疏/RAG。先弄清是否真需要全连接。
小结
一句话 — FlashAttention 让同一套 softmax 注意力在 GPU 上少搬内存、多分块,跑得更快;稀疏注意力用固定模式让每个 query 只看一部分 key,从根上少算、少占内存。
为什么一长就吃紧 — 长度 时分数表大约 。不是模型“坏了”,而是组合太多,显存和时间观感到瓶颈很正常。
Flash 用人话说 — 大块数据停在慢 HBM 里来回读写给不起,就把它拆成能在片上快速缓存里算完的小块。数学目标对齐标准注意,赚的是工程效率。
稀疏 用人话说 — 只让注意力看近邻 + 少量全局锚点之类的结构。省算力,但要留心:很远才出现的线索若不在模式里,效果可能掉;必要时配合 RAG 或调模式。
落地建议 — 先尽量用 Flash 扛住显存与延迟,再按需叠 稀疏/分段/检索。Ch01–03 的公式不变,本章教你执行层面的省与快。
核心公式(复习)— 分三步:logits ,按行 softmax 得 ,输出 。
位置 的输出行是 ,即用注意力权重混合 value 向量。 与 都涉及 ,所以长度 是成本核心。
本章要点:
• 痛点在 ~ 规模的分数工作。
• Flash=分块/内核,同一注意力。
• 稀疏=模式 + 验证。
解题提示
小结 — FlashAttention 用 IO 友好内核加速同一套 softmax 注意力;稀疏注意力用模式限制每个 query 看的 key,从而减少连接与计算。 变大时 常成为瓶颈,实务中在 Flash 与稀疏模式之间权衡显存、延迟与效果。
速记: Flash=同样的注意,更快路径;稀疏=少连边,设计先行。OOM 先查 Flash 是否开启 与 长度/batch。
落地清单:
• 要保持同一数学只提速 → Flash 优先。
• 要减连接 → 稀疏 + 评测。
• 远距线索重要 → 改模式或加 RAG。
- 题型概念选择
- 填写内容①②③ 序号 → 1–3
- 题型对/错
- 填写内容对/错 → 1/0
- 题型情景
- 填写内容最贴合项 → 1–3
- 题型投票和
- 填写内容01 向量中 1 的个数 → 整数
- 题型求和
- 填写内容数值总和 → 整数
- 题型网格
- 填写内容边长 的方格数 → 整数
- 题型集成/权衡
- 填写内容最贴切描述 → 1–3
| 题型 | 填写内容 |
|---|---|
| 概念选择 | ① ② ③ 序号 → 1–3 |
| 对/错 | 对/错 → 1/0 |
| 情景 | 最贴合项 → 1–3 |
| 投票和 | 01 向量中 1 的个数 → 整数 |
| 求和 | 数值总和 → 整数 |
| 网格 | 边长 的方格数 → 整数 |
| 集成/权衡 | 最贴切描述 → 1–3 |
例(概念) “FlashAttention 的目标最接近?
①只做近似
②以 IO 友好方式加速 dense softmax 注意力
③去掉注意力” → 2
例(判断) “分块有助于利用片上内存。对填 1,错填 0” → 1
例(情景) “要求数学上等价输出:
①精确 Flash
②删所有 Q-K 边
③只升学习率” → 1
例(投票) [1,1,1,1,0] 中 1 的个数 → 4
例(求和) [7,8,7] → 22
例(网格) 边长 12 → → 144
例(集成) “稀疏常见风险?
② 或损失少见的长距依赖” → 2