みんなのAI
機械学習AI論文
読み込み中…

学ぶ

🏅マイ実績

Ch.04

アテンション最適化:FlashAttentionとスパースアテンション

密(Dense)は全面N×NN\times NN×Nで表現力は高いが長文で重い。FlashはタイルでHBM往復削減。スパースは局所+グローバルで辺を削る。

密 vs Flash vs スパース

密

Q×K
密: 全(i,j)をスコア→表現力↑、NNN大でコスト↑。

Flash

Q×K
タイル: 同じsoftmaxをタイルで—HBM往復↓、体感速度↑。

スパース

Q×K
スパース: 窓+少数グローバル→有効位置↓、長距離は設計で補う。

三つの選択

①ボトルネック(NNN,OOM)→②Flash/パターン→③検証
① Flash: IO効率で同じsoftmax。
② スパース: 辺削減—長距離をマスクに。
③ 密: N2N^2N2を常に前提化。
④ 運用: OOM→Flash·バッチ·dtype / 品質→グローバル·RAG。
長さNNNの自己アテンションは、トークン間スコアがおおむねN×NN \times NN×Nになり、メモリ・演算が急増します。長文や大バッチでOOMや遅延が出るのは「注意が悪い」のではなく、長さに対するコスト曲線が急だからです。
FlashAttentionは同じsoftmax注意を、タイリングとメモリ階層に合わせてHBM往復を減らす実装で速めます。スパース注意は全キーを見せず、局所窓+少数グローバルなどパターンで辺を削ります。Flash=実装で同じ計算を速く、スパース=つながり設計で計算を減らす、この二本立てを学習・推論に結びつけます。
核となる数式 (scaled dot-product / softmax注意) 1ヘッドは
Attention(Q,K,V)=softmax(QKTdk)V\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V
まずスケール済みロジット S=QKT/dkS=QK^T/\sqrt{d_k}S=QKT/dk​​ を作ります。SijS_{ij}Sij​は「位置iiiがキーjjjをどれだけ見るか」の生スコアで、dk\sqrt{d_k}dk​​で割るのは次元dkd_kdk​が大きいと内積が大きくなりすぎ、softmaxが極端に尖りやすいのを和らげるためです。各行iiiでsoftmaxj\mathrm{softmax}_jsoftmaxj​により重みAijA_{ij}Aij​(∑jAij=1\sum_j A_{ij}=1∑j​Aij​=1)を得て、VVVの行を重み付き和して出力の行iiiを作ります。
ベクトルで言うとQQQのiii行qiq_iqi​とKKKのjjj行kjk_jkj​からSij=(qi⋅kj)/dkS_{ij}=(q_i\cdot k_j)/\sqrt{d_k}Sij​=(qi​⋅kj​)/dk​​、出力行iiiは∑jAijVj\sum_j A_{ij}V_j∑j​Aij​Vj​です。Flashはこの写像をメモリ移動を抑えて計算し、スパースはsoftmax前にj∉Sij\notin S_ij∈/Si​へ−∞-\infty−∞マスクして実質的な列集合だけ残します。

数式の要点

系列長NNNのとき、Q,K,VQ,K,VQ,K,Vの各行は1トークンに対応するベクトル(通常Q,KQ,KQ,Kは次元dkd_kdk​、VVVはdvd_vdv​)。まずスケール済み類似度(ロジット)行列を作ります。
S=QKTdkS=\frac{QK^T}{\sqrt{d_k}}S=dk​​QKT​
SijS_{ij}Sij​は「クエリiiiがキーjjjとどれだけ似ているか」のスコアです。1/dk1/\sqrt{d_k}1/dk​​はスケールド内積:次元が大きいと内積が大きくなり、softmaxが極端に尖りやすいのを抑えます。
各行でキー方向にsoftmaxし、注意重みAAAを得ます。
A=softmax(S)=softmax(QKTdk),∑j=1NAij=1A=\mathrm{softmax}(S)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad \sum_{j=1}^{N} A_{ij}=1A=softmax(S)=softmax(dk​​QKT​),j=1∑N​Aij​=1
最後にVVVの行を混ぜます。出力の行iiiは値ベクトルの加重和です。
outputi=∑j=1NAijVj⇔Attention(Q,K,V)=AV\mathrm{output}_i=\sum_{j=1}^{N} A_{ij}V_j \quad\Leftrightarrow\quad \mathrm{Attention}(Q,K,V)=AVoutputi​=j=1∑N​Aij​Vj​⇔Attention(Q,K,V)=AV
2段を1行にまとめると次の形です。
Attention(Q,K,V)=softmax(QKTdk)V\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk​​QKT​)V
コストの背骨はQKTQK^TQKT(またはSSS)のN×NN\times NN×Nです。行列積はおおよそO(N2dk)O(N^2 d_k)O(N2dk​)、要素数はN2N^2N2。NNNを2倍にすると概ね4倍のセルです。行softmaxとAVAVAVも同じN2N^2N2構造の仕事が続きます。バッチ・ヘッドで実VRAM/FLOPsは増えますが、「長文で効く」本質はこの二乗グリッドです。
スパース注意は「すべてのjjjは不要」をsoftmax前のロジットに書き込みます。E=QKT/dkE=QK^T/\sqrt{d_k}E=QKT/dk​​とし、クエリiiiの許可キー集合をSiS_iSi​とすると、j∉Sij\notin S_ij∈/Si​の(i,j)(i,j)(i,j)でEij=−∞E_{ij}=-\inftyEij​=−∞(実装では十分小さい負の値)とします。
softmax(Ei:)j≈0(j∉Si)\mathrm{softmax}(E_{i:})_j \approx 0 \quad (j\notin S_i)softmax(Ei:​)j​≈0(j∈/Si​)
するとAijA_{ij}Aij​は実質SiS_iSi​上だけ非ゼロとなり、式は同じくoutput=AV\mathrm{output}=AVoutput=AVですが計算・メモリは許可ペア中心になります。
FlashAttentionはsoftmax(QKT/dk)V\mathrm{softmax}(QK^T/\sqrt{d_k})Vsoftmax(QKT/dk​​)Vが与える注意出力を同じ目標にしつつ、全体N×NN\times NN×Nを遅いメモリに常駐させず、タイル単位で部分softmax統計を速いオンチップで積み上げます。何を混ぜるかは同じで、いつどこに置くかが違う実装です。

アテンション最適化: FlashAttentionとスパースアテンション

1. 重くなる理由 概念: 長さNNNでスコア表は約N×NN \times NN×N。直感: 全対全はだいたい二乗に近づく。実務: 長PDF・巨大プロンプトでOOMや遅延の主因。
覚えておく: 問題の中心はコスト曲線です。
2. FlashAttention HBMとSRAMの往復を減らし、タイルをオンチップで回す。同じsoftmax注意をカーネルで速くする話です。
数式: 変えないのは softmax(QKTdk)V\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)Vsoftmax(dk​​QKT​)V — 中間の出し方・場所だけがメモリ効率寄りになります。
覚えておく: 基本は同じ計算・速い経路です。
3. スパース 窓+少数グローバルなど許可パターンだけ見せ、∣Si∣≪N|S_i| \ll N∣Si​∣≪N。典型: ロジット E=QKT/dkE=QK^T/\sqrt{d_k}E=QKT/dk​​ に対し j∉Sij\notin S_ij∈/Si​ で Eij=−∞E_{ij}=-\inftyEij​=−∞(softmax前)→ 行 iii の質量は SiS_iSi​ に集中。
利点: メモリ・演算削減。注意: 珍しい長距離が要るタスクではパターン外が痛い。
4. 違い Flash=出力を揃えつつ高速化。スパース=接続変更。現場: まずFlash、足りなければスパース+評価。

重要性

コストと規模 注意はGPU時間とピークVRAMの中心になりがち。Flash等で同じ重みをより長い文脈やより大きいバッチで回せます。長文 書籍級プロンプトの背後には効率注意があり、プロダクト競争力に直結します。精度と速度 辺を減らすほど速いが証拠を落とす恐れ。ベンチで長距離を確認し、グローバル拡大やRAGを検討。Ch01–03 Q,K,VQ,K,VQ,K,V とsoftmaxは同じ。ここは実行と誰が誰を見るかの話。

使い方

学習 Flash/SDPAをオンにしピークVRAMとステップ時間を見る。同じGPUでバッチ/文脈を少し伸ばせることが多い。推論 KVキャッシュと高速カーネルがTTFTとデコードを決める。体感速度に直結。パターン 文書:局所+段落アンカー。コード:スコープのため窓拡大。データがマスクを決める。
OOM
①文長・バッチ・dtype →
②Flash有効か →
③ダメなら チャンク/スパース/RAG。本当に全対全が要か先に整理。

まとめ

ひと言で — FlashAttentionは同じsoftmax注意をタイル単位で速いメモリに載せ替え、メモリの往復を減らして速くします。スパース注意は全部のキーを見ないようにパターンを決め、演算とメモリを抑えます。
長さがきつい理由 — 長さNNNならスコアはおおよそN×NN \times NN×N。文が伸びるほど負担がぐっと増えるのは自然な結果です。
Flashをやさしく — 数式そのものより、GPUのHBM/SRAMに合わせて小さく分割して計算する発想です。出力の意味は通常の注意と揃えるのが基本で、実装の勝ちです。
スパースをやさしく — 近所や代表トークンなど、見ていい場所を決めます。軽くなる代わりに、離れたところの証拠が要るタスクではパターンを誤ると不利になります。
実務 — まずFlashでボトルネックを下げ、必要ならスパースやRAGを検討。Ch01–03の考え方は同じで、ここは実行の工夫の章です。
核となる数式(復習)— 段階はロジット S=QKT/dkS=QK^T/\sqrt{d_k}S=QKT/dk​​、行softmaxでAAA、出力AVAVAV。
Attention(Q,K,V)=softmax(QKTdk)V=AV\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V=AVAttention(Q,K,V)=softmax(dk​​QKT​)V=AV
位置iiiの出力行は∑jAijVj\sum_j A_{ij}V_j∑j​Aij​Vj​、つまり値ベクトルを注意重みで混ぜたものと覚えるとよいです。QKTQK^TQKTとAAAがN×NN\times NN×Nを持つので長さNNNがコストの中心です。
本章の核:
• 負担の素は長さに伴うN2N^2N2近傍コスト。
• Flash=カーネル/タイルで同じ注意を高速に。
• スパース=パターン設計と検証が命。

解法のヒント

まとめ — FlashAttentionは同じsoftmax注意をメモリ階層に合わせたカーネルで高速化し、スパース注意はパターンで見るキーを制限して接続と計算を減らします。NNNが大きいとN2N^2N2がボトルネックになりがちなので、OOMと品質のバランスでFlashかスパース設計を選びます。
覚え方: Flash=同じ計算を速い経路で、スパース=辺を減らす設計。OOMならFlashオンと文長・バッチを先に疑う。
現場チェック:
• 同じ数学を保ち速く → Flash優先。
• 演算を減らす → スパース+評価。
• 遠い証拠が重要 → パターン見直し・RAG。
  • タイプ概念選択
  • 入れる値
    ①
    ②
    ③の番号 → 1–3
  • タイプO/X
  • 入れる値正/誤 → 1/0
  • タイプシナリオ
  • 入れる値状況に合う選択 → 1–3
  • タイプ投票の和
  • 入れる値ベクトルの1の個数 → 整数
  • タイプ値の和
  • 入れる値数列の合計 → 整数
  • タイプ格子
  • 入れる値一辺nnnのマス数 n2n^2n2 → 整数
  • タイプアンサンブル等
  • 入れる値最も適切な説明 → 1–3
タイプ入れる値
概念選択
①
②
③の番号 → 1–3
O/X正/誤 → 1/0
シナリオ状況に合う選択 → 1–3
投票の和ベクトルの1の個数 → 整数
値の和数列の合計 → 整数
格子一辺nnnのマス数 n2n^2n2 → 整数
アンサンブル等最も適切な説明 → 1–3
例(概念)
「FlashAttentionの目的に近いのは?
①近似のみ
②IOを意識した実装でdense softmax注意を高速化
③注意を完全削除」
同じ注意を速く → 2

例(O/X)
「タイリングはオンチップメモリ活用に役立つ。正=1、誤=0」
正 → 1

例(シナリオ)
「数学的に同一の注意出力が必要。
①正確なFlash
②全辺削除
③LR10倍のみ」
① → 1

例(投票) votes=[1,1,1,1,0] の1の個数 → 4

例(合計) [7,8,7] の和 → 22

例(格子) 一辺12 のマス数 → 144144144 → 144

例(アンサンブル) スパースのリスクに近いのは?
② 稀な長距離喪失の可能性 → 2