Everyone's AI
Ch.04

Attention Optimization: FlashAttention and Sparse Attention

Dense fills the full N×N grid—expressive but heavy at long N. Flash tiles the same math to cut HBM traffic. Sparse drops edges via local + global patterns.

Dense vs Flash vs Sparse

Dense

Dense: Every (i, j) pair scored → expressivity ↑; high N → cost ↑.

Flash

Flash: Same softmax, tile-wise → fewer HBM trips, speed ↑.

Sparse

Sparse: Window + globals → fewer active products; long-range by design.

Three choices at a glance

① Bottleneck (N, OOM) → ② Flash / sparse → ③ validate metrics
① Flash: IO-aware path for identical softmax.
② Sparse: cut edges—long range in the mask.
③ Dense: budget N² always.
④ Ops: OOM → Flash · batch · dtype / quality → globals · RAG.
Self-attention is a core trick for mixing tokens, but as length N grows, the score grid (~N×N) makes memory and FLOPs spike—long inputs or big batches slam into OOM and latency. The issue is not “attention is wrong” but that cost grows too fast with length.
FlashAttention keeps the same softmax attention math and wins on tiling + memory hierarchy (fewer slow HBM round-trips, more work in fast on-chip memory). Sparse attention shrinks the wiring: each query attends only to a pattern (e.g. local window + a few globals). You’ll see when to pick Flash (same mapping, faster run) vs sparse (fewer edges, design + validation) and how both show up in training and serving.
Core formula (scaled dot-product / softmax attention). One head computes

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

Build scaled logits S = QKᵀ/√d_k. Entry S_ij is a raw “how much should position i look at key j?” score; dividing by √d_k keeps dot products from growing too large as the head dimension d_k grows, so softmax does not collapse to one-hot too harshly. For each row i, row-softmax gives weights A_ij with Σ_j A_ij = 1, and those weights blend the rows of V to form output row i.
In vector terms: rows q_i, k_j of Q, K give S_ij = (q_i · k_j)/√d_k, and output row i is Σ_j A_ij V_j. Flash computes this map with less memory movement; sparse masks logits before softmax so only a subset of columns gets mass for each row.

Formulas in plain words

With sequence length N, each row of Q, K, V is one token’s vector (dimension d_k for queries/keys, d_v for values in the usual setup). First form the scaled similarity / logits matrix

S = QKᵀ / √d_k

S_ij scores “query i vs key j”. The 1/√d_k scaling is the scaled dot-product trick: inner products grow with dimension, and softmax would become very peaked without this scale.
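The scaling claim is easy to check numerically: for random q and k with unit-variance entries, the dot product’s standard deviation grows like √d_k, and dividing by √d_k brings it back to roughly 1. A minimal sketch (NumPy; the dimension 256 and sample count are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 256            # head dimension (illustrative choice)
n_samples = 20_000

# Unit-variance random query and key vectors.
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

dots = np.einsum("nd,nd->n", q, k)       # one dot product per (q, k) pair
print(np.std(dots))                      # ≈ sqrt(256) = 16: grows with d_k
print(np.std(dots / np.sqrt(d_k)))       # ≈ 1 after the 1/sqrt(d_k) scale
```

Without the scale, logits with spread ~16 would drive softmax to near one-hot rows; with it, the spread stays O(1) regardless of head size.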
Apply softmax across keys (columns) for each query row to get attention weights A:

A = softmax(S) = softmax(QKᵀ / √d_k),  with Σ_{j=1..N} A_ij = 1

Finally mix value rows with those weights. Output row i is a convex combination of V rows:

output_i = Σ_{j=1..N} A_ij V_j  ⇔  Attention(Q, K, V) = A·V

Combine the two steps in one line:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
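The three stages (logits → row-softmax → value mix) can be written out directly. A minimal dense reference in NumPy; the shapes and random inputs are made up for illustration:

```python
import numpy as np

def dense_attention(Q, K, V):
    """Reference softmax attention: S -> A -> A @ V."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)                # logits, shape (N, N)
    S = S - S.max(axis=-1, keepdims=True)     # stabilize before exp
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)     # row-softmax: each row sums to 1
    return A @ V                              # blend value rows with weights

rng = np.random.default_rng(0)
N, d_k, d_v = 6, 8, 4
Q = rng.standard_normal((N, d_k))
K = rng.standard_normal((N, d_k))
V = rng.standard_normal((N, d_v))
out = dense_attention(Q, K, V)
print(out.shape)   # (6, 4): one mixed value vector per query row
```

Note the intermediate S and A are both N×N; this is exactly the tensor the rest of the chapter worries about.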
The backbone cost is materializing QKᵀ (or S): shape N×N. The matmul is roughly O(N²·d_k), and the tensor has N² entries—double N and entries scale ~4×. Row softmax and the A·V product add more work tied to that same N² structure. Batch size and multiple heads multiply real VRAM/FLOPs, but the “why long context hurts” story is this quadratic grid.
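To make the quadratic growth concrete, here is the raw size of a single fp16 N×N score grid per head (pure Python arithmetic; batch size and head count multiply this further, and it ignores everything else in the model):

```python
def score_grid_bytes(N, bytes_per_entry=2):
    """Bytes to materialize one dense N x N score grid (2 bytes = fp16)."""
    return N * N * bytes_per_entry

for N in (1_024, 4_096, 16_384):
    gib = score_grid_bytes(N) / 2**30
    # Doubling N quadruples the grid; 4x the length is 16x the memory.
    print(f"N={N:6d}: {gib:8.3f} GiB per head")
```

At N = 16,384 a single head’s grid is already half a GiB, which is why long prompts hit OOM before anything else does.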
Sparse attention encodes “don’t attend to every j” in the logits before softmax. With E = QKᵀ/√d_k and an allowed key set S_i for query i, set E_ij = −∞ when j ∉ S_i (in code, a large negative finite value is common).

softmax(E_i,:)_j ≈ 0 for j ∉ S_i

Then A_ij is effectively nonzero only on S_i, and you still write output = A·V, but compute/memory focus on the allowed pairs.
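A minimal sketch of that masking in NumPy; the window size, the global-token set, and the −1e9 stand-in for −∞ are all illustrative choices (a real sparse kernel would never build the full N×N grid at all—this only shows the math):

```python
import numpy as np

def sparse_attention(Q, K, V, window=2, globals_=(0,)):
    """Softmax attention where query i sees keys with |i-j| <= window, plus globals."""
    N, d_k = Q.shape
    i, j = np.indices((N, N))
    allowed = (np.abs(i - j) <= window) | np.isin(j, globals_)
    E = Q @ K.T / np.sqrt(d_k)
    E = np.where(allowed, E, -1e9)            # mask disallowed logits pre-softmax
    E = E - E.max(axis=-1, keepdims=True)
    A = np.exp(E)
    A = A / A.sum(axis=-1, keepdims=True)     # rows still sum to 1
    return A @ V, A

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
out, A = sparse_attention(Q, K, V)
print(A[7].round(3))   # mass only near j=7 and on the global key j=0
```

Row 7 gets weight only on its local window {5, 6, 7} and the global token 0; every other column is driven to zero by the mask.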
FlashAttention targets the same attention outputs as softmax(QKᵀ/√d_k)·V but avoids keeping the full N×N matrix resident in slow memory: it tiles the computation, accumulates partial softmax statistics in fast on-chip memory, and fuses steps. Same mapping; better data movement.
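The core trick is the online (streaming) softmax: visit K and V one tile at a time, carry only a running row max, a running normalizer, and a running weighted sum, and the final result matches the dense computation. This NumPy version is a sketch of the idea only—real FlashAttention is a fused GPU kernel, and the block size 4 here is arbitrary:

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Same outputs as dense softmax attention, but K/V are visited in tiles."""
    N, d_k = Q.shape
    out = np.zeros((N, V.shape[1]))
    m = np.full(N, -np.inf)      # running row max of logits seen so far
    l = np.zeros(N)              # running softmax normalizer
    for s in range(0, N, block):
        Kb, Vb = K[s:s+block], V[s:s+block]
        Sb = Q @ Kb.T / np.sqrt(d_k)          # logits for this tile only
        m_new = np.maximum(m, Sb.max(axis=1))
        scale = np.exp(m - m_new)             # rescale previous partial sums
        p = np.exp(Sb - m_new[:, None])
        l = l * scale + p.sum(axis=1)
        out = out * scale[:, None] + p @ Vb
        m = m_new
    return out / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
dense = np.exp(Q @ K.T / np.sqrt(8))
dense = (dense / dense.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention(Q, K, V), dense))   # True
```

No N×N tensor ever exists: each tile’s logits are consumed immediately, which is the “fewer HBM round-trips” story in miniature.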

Attention Optimization: FlashAttention and Sparse Attention

1. Why it gets heavy Concept: Attention materializes something like an N×N score table; every extra token adds pairwise work with all the others. Intuition: handshakes—cost scales like N² when every token can attend to every token. In practice: long PDFs / prompts / batches OOM or crawl unless you optimize.
Remember: the enemy is the cost curve, not the idea of attention itself.
2. FlashAttention Cut HBM traffic by tiling: keep hot blocks in fast SRAM, fuse steps in kernels, still compute standard softmax attention. Same math, better schedule on real hardware.
Formula target (unchanged): softmax(QKᵀ / √d_k) · V — only the order / tiling of partials changes, not the mapping (up to normal numerical details).
Remember: default story is exact dense attention, faster, not a different objective.
3. Sparse attention Fix a pattern—local neighbors + a few globals is common—so each query touches |S_i| ≪ N keys. Pros: less memory & compute. Cons: may miss rare long-range links if the pattern skips them.
Typical logits: with E = QKᵀ/√d_k, set E_ij = −∞ for masked (i, j) so softmax mass stays on allowed columns.
Remember: sparse is design—what you don’t attend to matters.
4. Key difference Flash targets identical outputs, faster. Sparse redesigns connectivity; outputs need measurement. Playbook: enable Flash, profile; if still tight, explore sparse / chunking / RAG with quality checks.

Why it matters

• Cost & scale: Attention is often the GPU-hour and VRAM peak driver. Flash / fused kernels / sparser masks let the same weights serve longer sequences or larger batches without buying more chips first.
• Long context: Products that swallow whole docs rely on efficient attention to push usable context windows. Latency and unit economics follow directly.
• Quality vs speed: Each edge you drop saves work but may drop evidence. Benchmark tasks with long-range needs; widen globals or add retrieval when patterns are thin.
• Connects to Ch01–03: Q, K, V and the softmax logic stay. Here we only change execution and who attends to whom.

How it is used

• Training: Flip Flash / SDPA / fused attention when your framework supports it—watch peak VRAM and step time. Same GPU, slightly larger batch or seq is often the first win.
• Inference: KV cache reuses K, V; kernel efficiency drives TTFT and per-token decode. Users feel every millisecond.
• Sparse patterns: Docs: local + paragraph anchors. Code: widen windows or add globals for scope/braces. Let the data shape the mask.
OOM triage
① length · batch · dtype →
② Flash truly on? →
③ still stuck? try chunking, sparsity, or RAG after you confirm long-range is truly required.
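Step ① of the triage is arithmetic you can do before touching the cluster. A rough estimator for the dense score-grid footprint alone (pure Python; it ignores weights, activations, and KV cache, so treat it as a lower bound, and the example batch/head/length numbers are made up):

```python
def attention_scores_gib(batch, heads, seq_len, bytes_per_entry=2):
    """GiB to materialize batch * heads dense (N x N) fp16 score grids."""
    return batch * heads * seq_len**2 * bytes_per_entry / 2**30

# Halving sequence length cuts this term 4x; halving batch cuts it 2x:
print(attention_scores_gib(batch=8, heads=32, seq_len=8192))   # 32.0 GiB
print(attention_scores_gib(batch=8, heads=32, seq_len=4096))   # 8.0 GiB
```

If this lower bound alone is near your card’s VRAM, no amount of step-③ cleverness will save the run; shorten, shrink, or switch kernels first.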

Summary

In one breath — FlashAttention runs the same softmax attention faster by tiling work to match GPU memory hierarchy; sparse attention cuts how many keys each query sees using a fixed pattern, saving compute and memory.
Why length hurts — For length N, attention builds about N×N scores. Costs ramp up fast; long prompts and big batches are where OOM and slowdown show up first.
Flash in plain English — Large tensors live in slow HBM; Flash chunks math so small tiles stay in fast on-chip memory and fewer round-trips are needed. The outputs match standard attention; the win is implementation.
Sparse in plain English — You don’t let every token attend to every token. Common mixes are nearby windows plus a few global anchor tokens. You trade full connectivity for efficiency; rare long-range facts need a careful pattern or other tools (e.g. RAG).
How people use this — Enable Flash first for safe speedups. Add or tune sparse layouts when you must squeeze further and can measure quality. Ch01–03 math stays; this is how we execute it efficiently.
Key equation (review) — stages: logits S = QKᵀ/√d_k, row-softmax gives A, output A·V.

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V = A·V

Row i is Σ_j A_ij V_j: you mix value vectors with attention weights. Both QKᵀ and A touch an N×N grid—this is why length hurts.
Chapter core:
• Pain grows with ~N² score work.
• Flash = kernels / tiling, same attention.
• Sparse = patterns, verify quality.

How to approach problems

Summary — FlashAttention speeds up the same softmax attention via IO-aware kernels, while sparse attention restricts which keys each query attends to (windows, global tokens, patterns). As N grows, the N² score grid becomes the pain point; in practice you pick Flash vs sparsity patterns to balance memory, latency, and quality.
Memory aid: Flash = same map, faster path; sparse = fewer edges, design matters. On OOM, verify Flash on, then length / batch / dtype.
Field checklist:
• Need identical dense math, just faster → Flash first.
• Need cheaper attention → sparse pattern + eval.
• Rare long-range clues → revisit pattern + RAG.
Type → What to enter:
• Concept: Best matching option → integer 1–3
• T/F: True/false → 1 or 0
• Scenario: Best option for the story → 1–3
• Vote sum: Count of 1s in the binary vector → integer
• Aggregate: Sum of given numbers → integer
• Config / grid: Cells in an n×n grid = n² → integer
• Ensemble / trade-off: Best statement → 1–3
Example (concept)
"Closest to FlashAttention’s goal?
① Only approximate softmax
② IO-aware faster exact softmax attention
③ Remove attention entirely"
Still the same attention, faster to compute → 2

"Core idea of sparse attention?
① Always attend to all keys
② Restrict keys with a pattern
③ Use only FFN"
Limits connectivity with a pattern → 2

Example (T/F)
"Tiled GPU kernels can help exploit on-chip memory. Enter 1 if true, 0 if false."
Matches the Flash story → 1

Example (scenario)
"Policy: need mathematically identical attention output.
① Exact Flash path
② Delete all query-key edges
③ 10× LR only"
Keep the exact mapping → 1

Example (vote sum)
"votes = [1,1,1,1,0] — how many 1s?"
Sum = 4 → 4

Example (aggregate)
"values = [7,8,7] sum?"
22 → 22

Example (grid)
"Square grid side length 12 — how many cells?"
12² = 144 → 144

Example (ensemble)
"Common sparse-attention risk?
① Always higher accuracy
② May lose rare long-range deps
③ No GPU needed"
Sparsity can miss far evidence → 2