Ch.04
アテンション最適化:FlashAttentionとスパースアテンション
密(Dense)は全面で表現力は高いが長文で重い。FlashはタイルでHBM往復削減。スパースは局所+グローバルで辺を削る。
密 vs Flash vs スパース
密
密: 全(i,j)をスコア→表現力↑、大でコスト↑。
Flash
タイル: 同じsoftmaxをタイルで—HBM往復↓、体感速度↑。
スパース
スパース: 窓+少数グローバル→有効位置↓、長距離は設計で補う。
三つの選択
①ボトルネック(,OOM)→②Flash/パターン→③検証
① Flash: IO効率で同じsoftmax。
② スパース: 辺削減—長距離をマスクに。
③ 密: を常に前提化。
④ 運用: OOM→Flash·バッチ·dtype / 品質→グローバル·RAG。
長さの自己アテンションは、トークン間スコアがおおむねになり、メモリ・演算が急増します。長文や大バッチでOOMや遅延が出るのは「注意が悪い」のではなく、長さに対するコスト曲線が急だからです。
FlashAttentionは同じsoftmax注意を、タイリングとメモリ階層に合わせてHBM往復を減らす実装で速めます。スパース注意は全キーを見せず、局所窓+少数グローバルなどパターンで辺を削ります。Flash=実装で同じ計算を速く、スパース=つながり設計で計算を減らす、この二本立てを学習・推論に結びつけます。
核となる数式 (scaled dot-product / softmax注意) 1ヘッドは
まずスケール済みロジット を作ります。は「位置がキーをどれだけ見るか」の生スコアで、で割るのは次元が大きいと内積が大きくなりすぎ、softmaxが極端に尖りやすいのを和らげるためです。各行でにより重み()を得て、の行を重み付き和して出力の行を作ります。
ベクトルで言うとの行との行から、出力行はです。Flashはこの写像をメモリ移動を抑えて計算し、スパースはsoftmax前にへマスクして実質的な列集合だけ残します。
数式の要点
系列長のとき、の各行は1トークンに対応するベクトル(通常は次元、は)。まずスケール済み類似度(ロジット)行列を作ります。
は「クエリがキーとどれだけ似ているか」のスコアです。はスケールド内積:次元が大きいと内積が大きくなり、softmaxが極端に尖りやすいのを抑えます。
各行でキー方向にsoftmaxし、注意重みを得ます。
最後にの行を混ぜます。出力の行は値ベクトルの加重和です。
2段を1行にまとめると次の形です。
スパース注意は「すべてのは不要」をsoftmax前のロジットに書き込みます。とし、クエリの許可キー集合をとすると、ので(実装では十分小さい負の値)とします。
するとは実質上だけ非ゼロとなり、式は同じくですが計算・メモリは許可ペア中心になります。
アテンション最適化: FlashAttentionとスパースアテンション
1. 重くなる理由 概念: 長さでスコア表は約。直感: 全対全はだいたい二乗に近づく。実務: 長PDF・巨大プロンプトでOOMや遅延の主因。
覚えておく: 問題の中心はコスト曲線です。
2. FlashAttention HBMとSRAMの往復を減らし、タイルをオンチップで回す。同じsoftmax注意をカーネルで速くする話です。
数式: 変えないのは — 中間の出し方・場所だけがメモリ効率寄りになります。
覚えておく: 基本は同じ計算・速い経路です。
3. スパース 窓+少数グローバルなど許可パターンだけ見せ、。典型: ロジット に対し で (softmax前)→ 行 の質量は に集中。
利点: メモリ・演算削減。注意: 珍しい長距離が要るタスクではパターン外が痛い。
重要性
コストと規模 注意はGPU時間とピークVRAMの中心になりがち。Flash等で同じ重みをより長い文脈やより大きいバッチで回せます。長文 書籍級プロンプトの背後には効率注意があり、プロダクト競争力に直結します。精度と速度 辺を減らすほど速いが証拠を落とす恐れ。ベンチで長距離を確認し、グローバル拡大やRAGを検討。Ch01–03 とsoftmaxは同じ。ここは実行と誰が誰を見るかの話。
使い方
学習 Flash/SDPAをオンにしピークVRAMとステップ時間を見る。同じGPUでバッチ/文脈を少し伸ばせることが多い。推論 KVキャッシュと高速カーネルがTTFTとデコードを決める。体感速度に直結。パターン 文書:局所+段落アンカー。コード:スコープのため窓拡大。データがマスクを決める。
OOM
①文長・バッチ・dtype →
②Flash有効か →
③ダメなら チャンク/スパース/RAG。本当に全対全が要か先に整理。
まとめ
ひと言で — FlashAttentionは同じsoftmax注意をタイル単位で速いメモリに載せ替え、メモリの往復を減らして速くします。スパース注意は全部のキーを見ないようにパターンを決め、演算とメモリを抑えます。
長さがきつい理由 — 長さならスコアはおおよそ。文が伸びるほど負担がぐっと増えるのは自然な結果です。
Flashをやさしく — 数式そのものより、GPUのHBM/SRAMに合わせて小さく分割して計算する発想です。出力の意味は通常の注意と揃えるのが基本で、実装の勝ちです。
スパースをやさしく — 近所や代表トークンなど、見ていい場所を決めます。軽くなる代わりに、離れたところの証拠が要るタスクではパターンを誤ると不利になります。
実務 — まずFlashでボトルネックを下げ、必要ならスパースやRAGを検討。Ch01–03の考え方は同じで、ここは実行の工夫の章です。
核となる数式(復習)— 段階はロジット 、行softmaxで、出力。
位置の出力行は、つまり値ベクトルを注意重みで混ぜたものと覚えるとよいです。とがを持つので長さがコストの中心です。
本章の核:
• 負担の素は長さに伴う近傍コスト。
• Flash=カーネル/タイルで同じ注意を高速に。
• スパース=パターン設計と検証が命。
解法のヒント
まとめ — FlashAttentionは同じsoftmax注意をメモリ階層に合わせたカーネルで高速化し、スパース注意はパターンで見るキーを制限して接続と計算を減らします。が大きいとがボトルネックになりがちなので、OOMと品質のバランスでFlashかスパース設計を選びます。
覚え方: Flash=同じ計算を速い経路で、スパース=辺を減らす設計。OOMならFlashオンと文長・バッチを先に疑う。
現場チェック:
• 同じ数学を保ち速く → Flash優先。
• 演算を減らす → スパース+評価。
• 遠い証拠が重要 → パターン見直し・RAG。
- タイプ概念選択
- 入れる値①②③の番号 → 1–3
- タイプO/X
- 入れる値正/誤 → 1/0
- タイプシナリオ
- 入れる値状況に合う選択 → 1–3
- タイプ投票の和
- 入れる値ベクトルの1の個数 → 整数
- タイプ値の和
- 入れる値数列の合計 → 整数
- タイプ格子
- 入れる値一辺のマス数 → 整数
- タイプアンサンブル等
- 入れる値最も適切な説明 → 1–3
| タイプ | 入れる値 |
|---|---|
| 概念選択 | ① ② ③の番号 → 1–3 |
| O/X | 正/誤 → 1/0 |
| シナリオ | 状況に合う選択 → 1–3 |
| 投票の和 | ベクトルの1の個数 → 整数 |
| 値の和 | 数列の合計 → 整数 |
| 格子 | 一辺のマス数 → 整数 |
| アンサンブル等 | 最も適切な説明 → 1–3 |
例(概念)
「FlashAttentionの目的に近いのは?
①近似のみ
②IOを意識した実装でdense softmax注意を高速化
③注意を完全削除」
同じ注意を速く → 2
例(O/X)
「タイリングはオンチップメモリ活用に役立つ。正=1、誤=0」
正 → 1
例(シナリオ)
「数学的に同一の注意出力が必要。
①正確なFlash
②全辺削除
③LR10倍のみ」
① → 1
例(投票) votes=[1,1,1,1,0] の1の個数 → 4
例(合計) [7,8,7] の和 → 22
例(格子) 一辺12 のマス数 → → 144
例(アンサンブル) スパースのリスクに近いのは?
② 稀な長距離喪失の可能性 → 2