Standard attention launches 4 separate CUDA kernels per pass and round-trips an N×N matrix through HBM. FlashAttention fuses them into one kernel: stream K, V tiles into SRAM, run an online softmax, and write only the final output (plus O(N) softmax statistics) to HBM. The backward pass uses the same trick: instead of loading a saved N×N matrix, it recomputes attention block by block from Q, K and the tiny saved (m, ℓ) stats. Result: the N×N matrix never touches HBM on either pass, activation memory drops from O(N²) to O(N·d), and the output is exact, with no approximation. (EE-508 Transformers, Part 3)
Each box is one CUDA kernel launch. Vertical arrows are HBM round-trips between launches; in standard attention each one moves ~N² elements. FlashAttention collapses them all into one launch with no intermediate HBM traffic.
For each Q tile i, three small running quantities (m_i, ℓ_i, and the output accumulator O_i) live in SRAM and are updated as each K/V tile j streams through. After the last tile, the final O_i is written to HBM, and (m_i, ℓ_i) are saved for the backward pass (just O(N) values, vs O(N²) for the full matrix).
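The per-tile update can be written out explicitly. The block below is one common formulation, not necessarily the exact one behind the figure: it keeps an unnormalized accumulator (written Õ_i here, a symbol introduced for illustration) and divides by ℓ_i only once at the end. It assumes the 1/√d scaling is folded into Q, that m_i starts at −∞ with ℓ_i and Õ_i at 0, and that the exponential factors act row-wise.

```latex
\begin{aligned}
S_{ij} &= Q_i K_j^{\top} \\
m_i^{\text{new}} &= \max\bigl(m_i,\ \operatorname{rowmax}(S_{ij})\bigr) \\
\ell_i^{\text{new}} &= e^{\,m_i - m_i^{\text{new}}}\,\ell_i
    + \operatorname{rowsum}\bigl(e^{\,S_{ij} - m_i^{\text{new}}}\bigr) \\
\tilde{O}_i^{\text{new}} &= e^{\,m_i - m_i^{\text{new}}}\,\tilde{O}_i
    + e^{\,S_{ij} - m_i^{\text{new}}}\,V_j \\
O_i &= \tilde{O}_i \,/\, \ell_i \qquad \text{(after the last K/V tile)}
\end{aligned}
```

The factor e^{m_i − m_i^new} is what makes the update incremental: whenever a new tile raises the running max, everything accumulated so far is rescaled to the new reference point, so the final result matches a softmax computed over the whole row at once.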
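Standard pays ~N² of HBM traffic on the forward pass and again on the backward pass, where it reloads the saved A. Flash never moves an N×N matrix: it pays O(N·d) per sweep over K and V on both passes.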
The bottleneck isn't FLOPs, it's HBM. An A100 has 312 TFLOPS of FP16 tensor-core compute but only ~2 TB/s of HBM bandwidth. Standard attention reads and writes the N×N matrix multiple times: at N=4096, d=64, that's ~250 MB of traffic per layer per pass. The compute units sit idle waiting on memory.
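The arithmetic behind that figure is easy to reproduce. The sketch below is a back-of-the-envelope estimate under assumptions that are not stated in the text: FP16 storage (2 bytes per element), a single attention head, and a free parameter `passes` for how many times the N×N matrix crosses HBM. Under those assumptions, about eight crossings land in the neighborhood of the ~250 MB quoted above.

```python
# Back-of-the-envelope HBM cost of materializing the N x N score matrix.
# Assumptions (not from the text): FP16 storage, one attention head, and
# the number of passes over the matrix is a parameter to experiment with.
N = 4096
BYTES_PER_ELEM = 2        # FP16
HBM_BANDWIDTH = 2e12      # bytes/s, the ~2 TB/s figure quoted above

matrix_bytes = N * N * BYTES_PER_ELEM


def traffic_bytes(passes):
    """Total bytes moved if the N x N matrix crosses HBM `passes` times."""
    return passes * matrix_bytes


def transfer_time_us(nbytes):
    """Time to move `nbytes` at full HBM bandwidth, in microseconds."""
    return nbytes / HBM_BANDWIDTH * 1e6


for passes in (1, 4, 8):
    total = traffic_bytes(passes)
    print(f"{passes} passes: {total / 1e6:.0f} MB, "
          f"{transfer_time_us(total):.0f} us at 2 TB/s")
```

The transfer time is a lower bound, since it assumes the kernels reach peak bandwidth and ignores launch overhead.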
SRAM is ~10× faster. Each of the A100's 108 streaming multiprocessors has ~192 KB of L1/shared-memory SRAM, with an estimated ~19 TB/s of bandwidth in aggregate. If you can keep the working set there, you decouple from HBM bandwidth.
Forward unlock: kernel fusion. Standard attention is 4 separate kernel launches: matmul → mask → softmax → matmul, each writing its intermediate to HBM and the next reading it back, so the N×N attention matrix gets materialized. FlashAttention's online softmax rewrites the normalization as an incremental update so all four ops fit in one kernel: read Q, K, V tile by tile, write only O. No approximation, just smarter math.
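To check that the incremental update really is exact, here is a minimal pure-Python sketch of the forward pass. It is an illustration, not the CUDA kernel: it handles one query row at a time (effectively Br = 1), uses a hypothetical tile-size parameter `Bc` for the K/V tiles, and assumes the usual 1/√d score scaling, which the text does not spell out. The function names are made up for this example.

```python
import math
import random


def naive_attention(Q, K, V):
    """Reference: build each full score row, softmax it, then weight V."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        mx = max(s)
        p = [math.exp(x - mx) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z
                    for c in range(len(V[0]))])
    return out


def flash_forward(Q, K, V, Bc=4):
    """Online-softmax forward: K/V are consumed in tiles of Bc rows.

    Returns O plus the per-row stats (m, l) that a backward pass would reuse.
    """
    d = len(Q[0])
    dv = len(V[0])
    O, M, L = [], [], []
    for q in Q:
        m = -math.inf        # running max of the scores seen so far
        l = 0.0              # running sum of exp(score - m)
        acc = [0.0] * dv     # running unnormalized output row
        for start in range(0, len(K), Bc):
            Kt, Vt = K[start:start + Bc], V[start:start + Bc]
            s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kt]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)   # rescale old state to the new max
            p = [math.exp(x - m_new) for x in s]
            l = l * scale + sum(p)
            acc = [acc[c] * scale
                   + sum(p[j] * Vt[j][c] for j in range(len(Vt)))
                   for c in range(dv)]
            m = m_new
        O.append([a / l for a in acc])    # normalize once, at the end
        M.append(m)
        L.append(l)
    return O, M, L
```

Only `m`, `l`, and `acc` persist across tiles, which is why the working set fits in SRAM regardless of N. Comparing `flash_forward` against `naive_attention` on random inputs should agree to floating-point rounding.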
Backward unlock: recomputation. The standard backward pass needs the attention matrix A again to differentiate through softmax, and storing all of A from the forward pass costs O(N²) HBM. FlashAttention saves only the normalization stats (m, ℓ), which is O(N) values, and recomputes A in blocks during the backward kernel. A small extra FLOP cost buys 10–20× memory savings, and because the recomputed blocks never leave SRAM, the backward pass also avoids the N² round-trip through HBM.
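The recomputation step rests on one identity: given a row's saved (m, ℓ), any block of its softmax probabilities can be rebuilt from q and the matching K tile alone, as exp(s − m) / ℓ. The sketch below illustrates just that identity in plain Python; it is not the backward kernel. The helper names are hypothetical, the 1/√d scaling is assumed, and the stats are computed in one shot here rather than online to keep the example short.

```python
import math


def row_stats(q, K):
    """Stats a forward pass would save for one query row: max and exp-sum."""
    d = len(q)
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(s)
    l = sum(math.exp(x - m) for x in s)
    return m, l


def recompute_probs_block(q, K_block, m, l):
    """Rebuild one block of the softmax row from q, a K tile, and (m, l)."""
    d = len(q)
    return [math.exp(sum(a * b for a, b in zip(q, k)) / math.sqrt(d) - m) / l
            for k in K_block]


q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
m, l = row_stats(q, K)

# Recompute the row two keys at a time; the full row is never stored.
blocks = [recompute_probs_block(q, K[j:j + 2], m, l)
          for j in range(0, len(K), 2)]
row = [p for block in blocks for p in block]
```

Each block is already correctly normalized on its own, so the backward kernel can consume it immediately and discard it before loading the next tile.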
The Br trap. FlashAttention's K and V get re-read once for every Q block, N/Br times in total, on both passes. If you pick Br too small, the re-read overhead exceeds the savings from skipping the N×N matrix and Flash loses to standard. Real implementations set Br ≈ M / (4·d), where M is the SRAM budget, which typically works out to 64–128 rows for d=64 on an A100. The slider demo shows the effect: with the default N=16, d=8, Flash beats Standard at Br=8 but loses at Br=2.
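A toy traffic model makes the crossover concrete. The accounting below is an assumption made for illustration, not necessarily the model behind the slider: it counts elements rather than bytes, charges standard attention one read and one write of the N×N matrix per intermediate kernel, and ignores the O(N) stats that Flash writes.

```python
def standard_traffic(N, d):
    """Elements moved through HBM by a 4-kernel attention (assumed accounting)."""
    matmul1 = 2 * N * d + N * N    # read Q, K; write S
    mask = 2 * N * N               # read S; write masked S
    softmax = 2 * N * N            # read masked S; write P
    matmul2 = N * N + 2 * N * d    # read P, V; write O
    return matmul1 + mask + softmax + matmul2


def flash_traffic(N, d, Br):
    """Elements moved when K and V are re-read once per Q block of Br rows."""
    q_blocks = -(-N // Br)         # ceil(N / Br)
    read_q_write_o = 2 * N * d     # each Q row read once, each O row written once
    reread_kv = q_blocks * 2 * N * d
    return read_q_write_o + reread_kv


for Br in (8, 2):
    print(f"Br={Br}: flash={flash_traffic(16, 8, Br)}, "
          f"standard={standard_traffic(16, 8)}")
```

With N=16, d=8 this model gives 2048 elements for standard attention, 768 for Flash at Br=8, and 2304 for Flash at Br=2, which reproduces the win-then-lose pattern described above.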
FlashAttention-2 went further. The original kernel already fused softmax, dropout, masking, and both matmuls; version 2 cuts non-matmul FLOPs, parallelizes over the sequence dimension as well as batch and heads, and partitions work better across warps. Net: roughly 2× faster than the original FlashAttention, and it is now widely used in LLM training pipelines. (Companion: activation checkpointing generalizes this same recompute-vs-store trade-off to whole networks.)