Ch.04
Attention Optimization: FlashAttention and Sparse Attention
Dense attention fills the full n × n grid: expressive, but heavy at long sequence length n. Flash tiles the same math to cut HBM traffic. Sparse drops edges via local + global patterns.
Dense vs Flash vs Sparse
Dense
Dense: Every (i, j) pair scored → expressivity ↑, but n ↑ → cost ↑.
Flash
Tile: Same softmax, tile-wise → fewer HBM trips, speed ↑.
Sparse
Sparse: Window + globals → fewer active products; long-range by design.
Three choices at a glance
① Bottleneck (n², OOM) → ② Flash / sparse → ③ validate metrics
① Flash: IO-aware path for identical softmax.
② Sparse: cut edges—long range in the mask.
③ Dense: always budget its n² cost.
④ Ops: on OOM → Flash · batch · dtype; on quality loss → widen globals · RAG.
Self-attention is a core trick for mixing tokens, but as length n grows, the n × n score grid (~n² entries) makes memory and FLOPs spike: long inputs or big batches slam into OOM and latency walls. The issue is not that attention is wrong, but that its cost grows too fast with length.
FlashAttention keeps the same softmax attention math and wins via tiling + the memory hierarchy (fewer slow HBM round-trips, more work in fast on-chip memory). Sparse attention shrinks the wiring: each query attends only to a pattern (e.g. local window + a few globals). You’ll see when to pick Flash (same mapping, faster run) vs sparse (fewer edges, needs design + validation) and how both show up in training and serving.
Core formula (scaled dot-product / softmax attention). One head computes Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.
Build scaled logits S = QKᵀ / √d_k. Entry S_ij is a raw “how much should position i look at key j?” score; dividing by √d_k keeps dot products from growing too large as the head dimension grows, so softmax does not collapse to one-hot too harshly. For each row i, row-softmax gives weights A_ij with Σ_j A_ij = 1, and those weights blend the rows of V to form output row O_i.
In vector terms: rows of Q and K give S = QKᵀ / √d_k, and output row i is O_i = Σ_j A_ij V_j. Flash computes this map with less memory movement; sparse masks logits before softmax so only a subset of columns gets mass for each row.
Formulas in plain words
With sequence length n, each row of Q, K, V is one token’s vector (dimension d_k for queries/keys, d_v for values in the usual setup). First form the n × n scaled similarity / logits matrix S = QKᵀ / √d_k.
S_ij scores “query i vs key j”. The 1/√d_k scaling is the scaled dot-product trick: inner products grow with dimension, and softmax would become very peaked without this scale.
Apply softmax across keys (columns) for each query row i to get attention weights A: A_ij = exp(S_ij) / Σ_k exp(S_ik).
Finally mix value rows with those weights. Output row i is a convex combination of V’s rows: O_i = Σ_j A_ij V_j.
Combine the two steps in one line: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.
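The one-line formula can be sketched in plain Python. This is a minimal, illustration-only version on small lists; real implementations use batched GPU tensors.

```python
import math

def softmax(row):
    # Subtract the row max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Dense scaled dot-product attention on plain n x d lists of lists."""
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    out = []
    for q in Q:  # one output row per query
        logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        weights = softmax(logits)  # row-softmax over all n keys
        out.append([sum(w * v[d] for w, v in zip(weights, V))
                    for d in range(len(V[0]))])  # convex mix of value rows
    return out
```

Sanity check of the “convex combination” reading: with two identical keys, the weights are uniform, so the output row is exactly the average of the value rows.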
Sparse attention encodes “don’t attend to every j” in the logits before softmax. With S = QKᵀ / √d_k and an allowed key set K(i) for query i, set S_ij = −∞ when j ∉ K(i) (in code, a large negative finite value is common).
Then row A_i is effectively nonzero only on K(i), and you still write O = AV, but compute/memory focus on the allowed pairs.
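A minimal sketch of that masking step for one query row, using a large negative finite logit in place of −∞ (the `allowed` parameter name is illustrative):

```python
import math

NEG_INF = -1e9  # stands in for -inf; large enough to get ~zero weight after softmax

def sparse_attention_row(q, K, V, allowed):
    """One query row: keys outside `allowed` get NEG_INF logits before softmax."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) if j in allowed else NEG_INF
              for j, k in enumerate(K)]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    weights = [e / total for e in exps]  # mass lands only on allowed keys
    return [sum(w * v[d] for w, v in zip(weights, V)) for d in range(len(V[0]))]
```

With `allowed = {0}`, the output collapses to (approximately) value row 0, no matter what the masked keys say.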
1. Why it gets heavy. Concept: attention materializes an n × n score table; every extra token adds a full row and column of pairwise work. Intuition: handshakes; cost scales like n² when every token can attend to every token. In practice: long PDFs, prompts, and batches OOM or crawl unless you optimize.
Remember: the enemy is the cost curve, not the idea of attention itself.
2. FlashAttention. Cut HBM traffic by tiling: keep hot blocks in fast SRAM, fuse steps into kernels, and still compute standard softmax attention. Same math, better schedule on real hardware.
Formula target (unchanged): softmax(QKᵀ / √d_k) V. Only the order / tiling of partials changes, not the mapping (up to normal floating-point details).
Remember: default story is exact dense attention, faster, not a different objective.
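The “same math, better schedule” claim can be illustrated with the online-softmax trick that FlashAttention builds on: process keys tile by tile while carrying a running max, a running normalizer, and a running weighted sum, never materializing the full logit row. This sketch shows only the numerics for one query; the actual speedup comes from fused GPU kernels, which plain Python cannot demonstrate.

```python
import math

def flash_style_row(q, K, V, tile=2):
    """One query row, keys processed in tiles (online softmax).
    Matches dense softmax attention up to floating-point error."""
    scale = 1.0 / math.sqrt(len(q))
    run_max = float("-inf")   # running max of logits seen so far
    denom = 0.0               # running softmax normalizer
    acc = [0.0] * len(V[0])   # running unnormalized output row
    for start in range(0, len(K), tile):
        for k, v in zip(K[start:start + tile], V[start:start + tile]):
            s = scale * sum(qi * ki for qi, ki in zip(q, k))
            new_max = max(run_max, s)
            corr = math.exp(run_max - new_max)  # rescale old partials (0.0 on first key)
            w = math.exp(s - new_max)
            denom = denom * corr + w
            acc = [a * corr + w * vd for a, vd in zip(acc, v)]
            run_max = new_max
    return [a / denom for a in acc]
```

With four identical keys the weights are uniform, so the streamed result equals the mean of the value rows, exactly as the dense formula would give.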
3. Sparse attention. Fix a pattern (local neighbors + a few globals is common) so each query touches far fewer than n keys. Pros: less memory & compute. Cons: misses rare long-range links if the pattern skips them.
Typical logits: S = QKᵀ / √d_k; with mask M, set S_ij = −∞ for masked j so mass stays on allowed columns.
Remember: sparse is design—what you don’t attend to matters.
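One way to sketch the “local window + a few globals” design as an allowed-key set (the window size and global positions here are arbitrary illustration choices, not a recommendation):

```python
def allowed_keys(i, n, window=2, global_tokens=(0,)):
    """Keys that query i may attend to: +/- `window` neighbors plus fixed globals."""
    local = set(range(max(0, i - window), min(n, i + window + 1)))
    return local | {g for g in global_tokens if g < n}
```

For i = 5, n = 10, window = 1 this yields {0, 4, 5, 6}: the query itself, two neighbors, and the global token 0. Feeding such a set per query into a masked-logits routine is the whole pattern.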
Why it matters
Cost & scale: Attention is often the GPU-hour and VRAM-peak driver. Flash, fused kernels, and sparser masks let the same weights serve longer sequences or larger batches without buying more chips first.
Long context: Products that swallow whole docs rely on efficient attention to push usable context windows. Latency and unit economics follow directly.
Quality vs speed: Each edge you drop saves work but may drop evidence. Benchmark tasks with long-range needs; widen globals or add retrieval when patterns are thin.
Connects to Ch01–03: Q, K, V and the softmax logic stay. Here we only change execution and who attends to whom.
How it is used
Training: Flip on Flash / SDPA / fused attention when your framework supports it, and watch peak VRAM and step time. Same GPU, slightly larger batch or sequence length is often the first win.
Inference: The KV cache reuses K and V; kernel efficiency drives TTFT and per-token decode speed. Users feel every millisecond.
Sparse patterns: For docs, local windows + paragraph anchors; for code, widen windows or add globals for scopes/braces. Let the data shape the mask.
OOM triage
① Check length · batch · dtype →
② Is Flash truly on? →
③ Still stuck? Try chunking, sparsity, or RAG, after confirming long-range attention is truly required.
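For step ①, a back-of-envelope helper (hypothetical, illustration only) shows why the n × n score matrix is usually the first thing to blow up, and what Flash avoids materializing in HBM:

```python
def score_matrix_bytes(n, bytes_per_elem=2, heads=1, batch=1):
    """Bytes to materialize the n x n attention scores, once per head per batch item.
    bytes_per_elem=2 assumes fp16/bf16; use 4 for fp32."""
    return n * n * bytes_per_elem * heads * batch
```

At n = 32768 with 32 heads in fp16, that is 32768² × 2 × 32 = 64 GiB for the scores alone, before weights, activations, or the KV cache: a quick way to see whether length, batch, or dtype is the lever to pull.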
Summary
In one breath — FlashAttention runs the same softmax attention faster by tiling work to match GPU memory hierarchy; sparse attention cuts how many keys each query sees using a fixed pattern, saving compute and memory.
Why length hurts — For length n, attention builds about n² scores. Costs ramp up fast; long prompts and big batches are where OOM and slowdown show up first.
Flash in plain English — Large tensors live in slow HBM; Flash chunks math so small tiles stay in fast on-chip memory and fewer round-trips are needed. The outputs match standard attention; the win is implementation.
Sparse in plain English — You don’t let every token attend to every token. Common mixes are nearby windows plus a few global anchor tokens. You trade full connectivity for efficiency; rare long-range facts need a careful pattern or other tools (e.g. RAG).
How people use this — Enable Flash first for safe speedups. Add or tune sparse layouts when you must squeeze further and can measure quality. Ch01–03 math stays; this is how we execute it efficiently.
Key equation (review) — stages: logits S = QKᵀ / √d_k, row-softmax gives A, output O = AV.
Row i of the output is O_i = Σ_j A_ij V_j: you mix value vectors with attention weights. Both S and A touch an n × n grid — this is why length hurts.
Chapter core:
• Pain grows with ~n² score work.
• Flash = kernels / tiling, same attention.
• Sparse = patterns, verify quality.
How to approach problems
Summary — FlashAttention speeds up the same softmax attention via IO-aware kernels, while sparse attention restricts which keys each query attends to (windows, global tokens, patterns). As n grows, the n × n score grid becomes the pain point; in practice you pick Flash vs sparsity patterns to balance memory, latency, and quality.
Memory aid: Flash = same map, faster path; sparse = fewer edges, design matters. On OOM, verify Flash on, then length / batch / dtype.
Field checklist:
• Need identical dense math, just faster → Flash first.
• Need cheaper attention → sparse pattern + eval.
• Rare long-range clues → revisit pattern + RAG.
| Type | What to enter |
|---|---|
| Concept | Best matching option → integer 1–3 |
| T/F | True/false → 1 or 0 |
| Scenario | Best option for the story → 1–3 |
| Vote sum | Count of 1s in the binary vector → integer |
| Aggregate | Sum of given numbers → integer |
| Config / grid | Cells in an n × n grid = n² → integer |
| Ensemble / trade-off | Best statement → 1–3 |
Example (concept)
"Closest to FlashAttention’s goal?
① Only approximate softmax
② IO-aware faster exact softmax attention
③ Remove attention entirely"
Still the same attention, faster to compute → 2
"Core idea of sparse attention?
① Always attend to all keys
② Restrict keys with a pattern
③ Use only FFN"
Limits connectivity with a pattern → 2
Example (T/F)
"Tiled GPU kernels can help exploit on-chip memory. Enter 1 if true, 0 if false."
Matches the Flash story → 1
Example (scenario)
"Policy: need mathematically identical attention output.
① Exact Flash path
② Delete all query-key edges
③ 10× LR only"
Keep the exact mapping → 1
Example (vote sum)
"votes = [1,1,1,1,0] — how many 1s?"
Sum = 4 → 4
Example (aggregate)
"values = [7,8,7] sum?"
22 → 22
Example (grid)
"Square grid side length 12 — how many cells?"
12² = 144 → 144
Example (ensemble)
"Common sparse-attention risk?
① Always higher accuracy
② May lose rare long-range deps
③ No GPU needed"
Sparsity can miss far evidence → 2