Online Softmax
The streaming math trick behind FlashAttention
Step 0: The Setup
Deep Dive: Why naive per-block softmax fails, and the math
Naive idea: softmax each block separately
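Each block sums to 1.0 individually, so the concatenated result sums to the number of blocks rather than to 1, and the relative magnitudes between blocks are lost. Block 2 has a larger max, but naive softmax has no way to know that: each block is normalized only against its own values.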
Online (the streaming rule below) matches Batch (softmax over the full vector) exactly, up to floating-point rounding. Naive does not.
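The failure of the naive approach can be reproduced in a few lines. The following is a minimal sketch in plain Python; the two blocks are made-up example values (not the page's presets), chosen so that block 2 holds the larger numbers.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a flat list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical example data: block 2 holds the larger values.
blocks = [[1.0, 2.0], [5.0, 6.0]]

# Naive: softmax each block independently, then concatenate.
naive = [p for block in blocks for p in softmax(block)]

# Batch: softmax over the full concatenated vector.
batch = softmax([x for block in blocks for x in block])

print("naive sum:", sum(naive))  # one unit of probability per block
print("batch sum:", sum(batch))
print("naive:", naive)
print("batch:", batch)
```

Because softmax is shift-invariant, the naive result assigns both blocks the same pair of probabilities here, even though in the batch result block 2 takes almost all of the mass.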
The streaming update rule
Start with m = −∞ and ℓ = 0. Then, for each new block b with values x:
m_new = max(m_old, max(x))
scale = exp(m_old − m_new) // ≤ 1
ℓ_new = ℓ_old · scale + Σ exp(x − m_new)
// every saved exp value gets multiplied by scale too
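The update rule above can be sketched in plain Python as follows. This is an illustrative implementation under a few assumptions: blocks are non-empty lists of finite floats, the running state starts at m = −∞ and ℓ = 0, and the names `online_softmax`, `ell`, and `saved` are chosen here for illustration.

```python
import math

def online_softmax(blocks):
    """Streaming softmax over a sequence of non-empty blocks of floats."""
    m = -math.inf   # running max seen so far
    ell = 0.0       # running sum of exp(x - m)
    saved = []      # exp values, always expressed relative to the current m
    for x in blocks:
        m_new = max(m, max(x))
        scale = math.exp(m - m_new)  # <= 1; equals 0.0 on the first block
        fresh = [math.exp(v - m_new) for v in x]
        ell = ell * scale + sum(fresh)
        # Re-base every saved exp value to the new max, then append the new ones.
        saved = [e * scale for e in saved] + fresh
        m = m_new
    return [e / ell for e in saved]

def batch_softmax(xs):
    """Reference: softmax over the full vector in one pass."""
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical example data.
blocks = [[1.0, 2.0], [5.0, 6.0], [3.0]]
online = online_softmax(blocks)
batch = batch_softmax([v for b in blocks for v in b])
print(max(abs(a - b) for a, b in zip(online, batch)))  # tiny rounding error at most
```

Only `m` and `ell` are needed to keep the denominator correct; the `saved` list is kept here just so the final probabilities can be printed and compared.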
The key identity: exp(x − m_old) · exp(m_old − m_new) = exp(x − m_new). Old exp values can be retroactively re-based to a new max by multiplication.
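Applying this identity term by term to the running sum gives the ℓ update. As a short derivation sketch, write "old" for the indices already seen and b for the new block (the index names i and j are introduced here only for illustration):

```latex
\begin{aligned}
\ell_{\text{new}}
  &= \sum_{i \in \text{old}} e^{x_i - m_{\text{new}}}
   + \sum_{j \in b} e^{x_j - m_{\text{new}}} \\
  &= e^{m_{\text{old}} - m_{\text{new}}} \sum_{i \in \text{old}} e^{x_i - m_{\text{old}}}
   + \sum_{j \in b} e^{x_j - m_{\text{new}}} \\
  &= \ell_{\text{old}} \cdot \text{scale}
   + \sum_{j \in b} e^{x_j - m_{\text{new}}}
\end{aligned}
```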
FlashAttention connection: Replace each "block" with a tile of K/V loaded into SRAM. The same streaming rule produces exact softmax(QKᵀ)V without ever materializing the N×N attention matrix in HBM. Memory drops from O(N²) to O(N).
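To make the connection concrete, here is a plain-Python sketch of the same recurrence applied to attention for a single query row. It is not the FlashAttention kernel: there is no SRAM/HBM management, no batching, and no 1/√d scaling. The names `attention_streaming`, `acc`, and `tile`, and the example data, are assumptions made for illustration. Instead of saving the exp values, the sketch folds them straight into an output accumulator, which gets multiplied by `scale` in the same way.

```python
import math

def attention_reference(q, K, V):
    """Single-query attention that materializes the full row of scores."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, V)) / total
            for c in range(len(V[0]))]

def attention_streaming(q, K, V, tile=2):
    """Same result, visiting K/V one tile at a time with fixed-size state."""
    m = -math.inf            # running max of the scores
    ell = 0.0                # running softmax denominator
    acc = [0.0] * len(V[0])  # running unnormalized output
    for start in range(0, len(K), tile):
        k_tile = K[start:start + tile]
        v_tile = V[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in scores]
        ell = ell * scale + sum(p)
        # Re-base the accumulated output, then add this tile's contribution.
        acc = [a * scale + sum(pj * v[c] for pj, v in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / ell for a in acc]

# Hypothetical example data: one query, five key/value pairs of dimension 2.
q = [1.0, 0.5]
K = [[0.1, 0.2], [1.5, -0.3], [0.7, 0.9], [-1.0, 2.0], [2.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]

ref = attention_reference(q, K, V)
out = attention_streaming(q, K, V, tile=2)
print(ref)
print(out)
```

The streaming version never holds more than one tile of scores at a time; its state per query is just `m`, `ell`, and `acc`, regardless of how many keys there are.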