Head grouping

Each query head still computes its own attention. Query heads in the same group share a single K/V head, so we cache fewer copies. That is the whole trick.
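To make the sharing concrete, here is a minimal NumPy sketch of grouped attention. The function name, the tensor layout, and the causal mask are assumptions chosen for illustration, not taken from any particular library.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Causal attention where several query heads share one K/V head.

    q has shape (n_q, seq_len, head_dim); k and v have shape
    (n_kv, seq_len, head_dim), with n_q divisible by n_kv.
    """
    n_q, seq_len, head_dim = q.shape
    n_kv = k.shape[0]
    if n_q % n_kv != 0:
        raise ValueError("n_q must be a multiple of n_kv")
    group_size = n_q // n_kv

    # Query head h reads K/V head h // group_size. Only the n_kv originals
    # would be cached; the repeat just lines heads up for the matmul.
    k_shared = np.repeat(k, group_size, axis=0)
    v_shared = np.repeat(v, group_size, axis=0)

    # Each query head still gets its own score matrix.
    scores = q @ k_shared.transpose(0, 2, 1) / np.sqrt(head_dim)

    # Causal mask: position i may not attend to positions after i.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)

    # Numerically stable softmax over the key axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_shared

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 5, 16))  # 8 query heads
k = rng.standard_normal((2, 5, 16))  # 2 K heads, so groups of 4
v = rng.standard_normal((2, 5, 16))  # 2 V heads
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 5, 16)
```

Passing k and v with as many heads as q reduces this to ordinary MHA, and passing a single K/V head gives MQA.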

KV cache growth as tokens stream in

Each square is one cached K/V vector. Every new token writes n_kv squares per layer. After the same number of tokens, MHA's row is much longer than GQA's, and GQA's much longer than MQA's — that gap is the memory savings.
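The picture can be roughly reproduced in text. This is only a sketch: the helper name `cache_row` and the head counts (8 query heads, 2 KV heads for GQA) are assumptions picked for illustration.

```python
def cache_row(n_tokens, n_kv_heads, square="■"):
    """One layer's cache row: each token appends n_kv_heads squares."""
    return square * (n_tokens * n_kv_heads)

# Assumed head counts for a model with 8 query heads.
variants = [("MHA", 8), ("GQA", 2), ("MQA", 1)]

for n_tokens in (1, 2, 4):
    print(f"after {n_tokens} token(s):")
    for name, n_kv in variants:
        print(f"  {name} {cache_row(n_tokens, n_kv)}")
```

The rows grow linearly with the token count, and the ratio between them stays fixed at the ratio of the KV head counts.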

KV cache memory at a given sequence length

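cache = 2 × n_layers × n_kv_heads × head_dim × seq_len × bytes_per_element
The factor of 2 covers both K and V. This is the size for one sequence; multiply by the batch size for a batch. Q is not cached: each decode step needs only the new token's query, which is computed once and then discarded. Reducing n_kv_heads is the only knob GQA/MQA touch.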
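As a worked example of the formula, the sketch below evaluates it for a hypothetical model. The configuration (32 layers, 32 query heads, head_dim 128, 4096 tokens, 2 bytes per element) is an assumption, not a specific published model, and the `batch_size` parameter is an addition for convenience.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   bytes_per_element=2, batch_size=1):
    """KV cache size in bytes; the leading 2 counts K and V."""
    return (2 * n_layers * n_kv_heads * head_dim * seq_len
            * bytes_per_element * batch_size)

# Hypothetical configuration: 32 layers, 32 query heads, fp16 cache.
N_LAYERS, HEAD_DIM, SEQ_LEN = 32, 128, 4096

for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(N_LAYERS, n_kv, HEAD_DIM, SEQ_LEN)
    print(f"{name}: {size / 2**30:.4f} GiB")
# MHA: 2.0000 GiB
# GQA: 0.5000 GiB
# MQA: 0.0625 GiB
```

Only n_kv_heads changes between the three lines, so the sizes scale exactly with it: 32 to 8 is a 4x reduction, 32 to 1 is 32x.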

Real-world models

MHA (classic): n_kv = n_q
GPT-2, GPT-3, BERT, original Transformer

GQA (sweet spot): 1 < n_kv < n_q
Llama 2-70B, Llama 3, Mistral 7B (8 KV heads each), Mixtral

MQA (most aggressive): n_kv = 1
PaLM, Falcon-7B, StarCoder, Gemini Nano
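The three categories follow directly from the two head counts. A minimal sketch, with a hypothetical helper name and illustrative head counts rather than real model configurations:

```python
def attention_variant(n_q_heads, n_kv_heads):
    """Classify an attention layout by its query and KV head counts."""
    if n_kv_heads < 1 or n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a positive multiple of n_kv_heads")
    if n_kv_heads == n_q_heads:
        return "MHA"  # every query head has its own K/V head
    if n_kv_heads == 1:
        return "MQA"  # all query heads share one K/V head
    return "GQA"      # 1 < n_kv < n_q

# Illustrative head counts only.
for n_q, n_kv in [(32, 32), (32, 8), (32, 1)]:
    print(n_q, n_kv, attention_variant(n_q, n_kv))
```

Seen this way, MHA and MQA are the two endpoints of GQA rather than separate mechanisms.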

The tradeoff

✓ Smaller KV cache
Lower memory → larger batch → higher inference throughput. The whole reason GQA exists.
✓ Faster decoding
Less data to load from HBM each token. Memory bandwidth is the bottleneck in autoregressive decode.
✗ Quality drop
MQA can cost noticeable quality, because every query head has to work from the same single K/V head. GQA recovers most of MHA's quality while keeping most of the memory win, which is why it became a common default.
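The first two points can be put into rough numbers. The sketch below is a back-of-the-envelope estimate under loudly stated assumptions: a hypothetical 32-layer model with head_dim 128 and an fp16 cache, a 16 GiB cache budget, and a round 1e12 bytes/s of memory bandwidth. It ignores the model weights and every other cost, so it shows the trend only.

```python
def kv_bytes_per_sequence(n_layers, n_kv_heads, head_dim, seq_len,
                          bytes_per_element=2):
    """Per-sequence KV cache size; the leading 2 counts K and V."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_element

def max_batch_size(cache_budget_bytes, bytes_per_sequence):
    """How many sequences fit in the memory set aside for the cache."""
    return cache_budget_bytes // bytes_per_sequence

def kv_read_ms(bytes_per_sequence, bandwidth_bytes_per_s):
    """Time to stream one sequence's cache once, as each decode step must."""
    return bytes_per_sequence / bandwidth_bytes_per_s * 1e3

CACHE_BUDGET_BYTES = 16 * 2**30   # assumed budget, not a real device spec
BANDWIDTH_BYTES_PER_S = 1e12      # assumed round number, not a real device spec

for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    per_seq = kv_bytes_per_sequence(32, n_kv, 128, 4096)
    batch = max_batch_size(CACHE_BUDGET_BYTES, per_seq)
    read = kv_read_ms(per_seq, BANDWIDTH_BYTES_PER_S)
    print(f"{name}: max batch {batch}, KV read {read:.3f} ms per sequence")
```

Both numbers move in proportion to n_kv_heads. The quality cost in the third point has no comparable closed form and has to be measured.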