Online Softmax

The streaming math trick behind FlashAttention

Step 0: The Setup

Deep Dive: Why naive per-block softmax fails, and the math

Naive idea: softmax each block separately

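Each block's outputs sum to 1.0 on their own, so the concatenated result sums to the number of blocks rather than 1.0, and the relative magnitudes between blocks are lost. If block 2 has a larger max than block 1, its entries should receive more weight, but a per-block softmax never sees the other block and has no way to know that.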

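The online (streaming) result matches the batch softmax over the full vector exactly, up to floating-point rounding. The naive per-block result does not.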
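
A minimal sketch of the failure described above, using only the Python standard library. The block values and the helper names softmax and naive_blockwise_softmax are hypothetical, chosen for illustration.

```python
import math

def softmax(xs):
    """Numerically stable softmax over one full list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def naive_blockwise_softmax(blocks):
    """Softmax each block on its own and concatenate (the broken approach)."""
    out = []
    for block in blocks:
        out.extend(softmax(block))
    return out

# Hypothetical example: block 2 has a much larger max than block 1.
blocks = [[1.0, 2.0], [5.0, 6.0]]
flat = [x for block in blocks for x in block]

naive = naive_blockwise_softmax(blocks)
batch = softmax(flat)

print(sum(naive))  # approximately 2.0: one unit of probability per block
print(sum(batch))  # approximately 1.0
print(naive)       # both blocks get the same weights, despite different maxes
print(batch)       # block 2 dominates, as it should
```

Because both blocks here have the same internal spacing, the naive version assigns them identical weights, which is exactly the information loss the text describes.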

The streaming update rule

Start from m = −∞ and ℓ = 0. Then, for each new block b with values x:

m_new = max(m_old, max(x))
scale = exp(m_old − m_new)   // ≤ 1
ℓ_new = ℓ_old · scale  +  Σ exp(x − m_new)
// every saved exp value gets multiplied by scale too
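
The update rule above can be sketched in plain Python as follows. This is an illustrative version, not a tuned implementation: the function name online_softmax is hypothetical, and it assumes every block is non-empty and contains only finite values.

```python
import math

def online_softmax(blocks):
    """Softmax over the concatenation of blocks, seeing one block at a time."""
    m = float("-inf")   # running max (m_old)
    l = 0.0             # running normalizer (l_old)
    exps = []           # saved values exp(x - m) for everything seen so far

    for block in blocks:
        m_new = max(m, max(block))
        scale = math.exp(m - m_new)          # <= 1; equals 0.0 on the first block
        exps = [e * scale for e in exps]     # re-base saved values to m_new
        new_exps = [math.exp(x - m_new) for x in block]
        l = l * scale + sum(new_exps)
        exps.extend(new_exps)
        m = m_new

    return [e / l for e in exps]

# Hypothetical example values.
print(online_softmax([[1.0, 2.0], [5.0, 6.0]]))
```

Starting from m = −∞ makes the first block need no special case, since exp(−∞) evaluates to 0.0 and the empty running state contributes nothing.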

The key identity: exp(x − m_old) · exp(m_old − m_new) = exp(x − m_new). Old exp values can therefore be re-based to a new max after the fact by multiplying them by a single scalar, exp(m_old − m_new), without revisiting the original x.
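
As a quick numeric check of this identity applied to the running sum ℓ, the sketch below rescales a sum computed against an old max and compares it with the same sum computed directly against the new max. The values are hypothetical.

```python
import math

old_block = [1.0, 3.0]   # hypothetical values already processed
m_old = max(old_block)
l_old = sum(math.exp(x - m_old) for x in old_block)

m_new = 7.0              # hypothetical larger max found in a later block

# Re-base the old sum with one multiplication...
l_rebased = l_old * math.exp(m_old - m_new)
# ...and compare with recomputing every term against the new max.
l_direct = sum(math.exp(x - m_new) for x in old_block)

print(math.isclose(l_rebased, l_direct))  # True
```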

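FlashAttention connection: Replace each "block" with a tile of K/V loaded into SRAM. The same streaming rule, with the running output accumulator rescaled by the same scale factor, produces exact softmax(QKᵀ)V without ever materializing the N×N attention matrix in HBM. Memory for the attention computation drops from O(N²) to O(N) in the sequence length N.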
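
To make the connection concrete, here is a simplified single-query sketch in plain Python. It is not the FlashAttention kernel: it is not tiled over queries, it does not model SRAM or HBM, and it omits the usual 1/√d score scaling to match the softmax(QKᵀ)V form above. The function name streaming_attention and the tile_size parameter are hypothetical.

```python
import math

def streaming_attention(q, keys, values, tile_size=2):
    """Attention output for one query, visiting K/V one tile at a time."""
    m = float("-inf")               # running max of the scores
    l = 0.0                         # running softmax normalizer
    acc = [0.0] * len(values[0])    # running unnormalized output

    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]

        # Scores for this tile only: dot(q, k) for each key in the tile.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]

        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)
        weights = [math.exp(s - m_new) for s in scores]

        l = l * scale + sum(weights)
        acc = [a * scale for a in acc]          # re-base the old output
        for w, v in zip(weights, v_tile):
            acc = [a + w * vj for a, vj in zip(acc, v)]
        m = m_new

    return [a / l for a in acc]

# Hypothetical example data: 5 keys/values of dimension 2.
q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(streaming_attention(q, keys, values, tile_size=2))
```

Only m, ℓ, and the output accumulator persist between tiles, so the per-query state does not grow with the number of keys.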