EE508-style visual aid • Multi-head attention

See how one sequence becomes many heads, then comes back together.

This visualizer follows the slide convention used in class: a sequence with shape 3 × 512 is projected into Q, K, V, split across multiple heads, processed independently, and then concatenated back to the model dimension. The right-hand math panel uses a tiny toy example so you can see the score calculation without drowning in 512 numbers.

X input embeddings: 3 × 512
Wq, Wk, Wv learned projections
4 heads → 128 dims per head
Attention = softmax(QKᵀ / √dₖ)V
Sequence length: 3. Three tokens in the slide-style example.
Model width: 512. The full embedding dimension before head-splitting.
Heads: 4. Each head sees its own learned projection.
Per-head dim: 128. 512 ÷ 4 = 128 dimensions per head.

1) From tokens to Q, K, V

The input stays as a 3 × 512 matrix. Multi-head attention does not split the input first; it creates learned projections and then splits the projected space into heads.

Input → Project → Split heads → Attend → Concat
X (3 × 512) → linear projections learned during training

Q, K, V come from the same X

Each projection is a learned matrix multiplication. The model learns Wq, Wk, and Wv so that some directions become better for matching (Q/K) and some become better for carrying information (V).
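The projection step can be sketched in plain Python. This is only an illustration of the shapes: the helper names (`matmul`, `random_matrix`, `shape`) are made up for this example, and the weight matrices are random stand-ins, since real Wq, Wk, and Wv would come from training.

```python
import random

def matmul(a, b):
    """Multiply a (n x m) by b (m x p); both are lists of rows."""
    cols = list(zip(*b))  # columns of b, so each entry is a row-column dot product
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def random_matrix(rows, cols, rng):
    """Random stand-in for a learned or embedded matrix (assumption, not trained)."""
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

def shape(m):
    return (len(m), len(m[0]))

rng = random.Random(0)
seq_len, d_model = 3, 512

X = random_matrix(seq_len, d_model, rng)    # input embeddings, 3 x 512
Wq = random_matrix(d_model, d_model, rng)   # placeholder for the learned Wq
Wk = random_matrix(d_model, d_model, rng)   # placeholder for the learned Wk
Wv = random_matrix(d_model, d_model, rng)   # placeholder for the learned Wv

# All three projections start from the same X.
Q = matmul(X, Wq)
K = matmul(X, Wk)
V = matmul(X, Wv)

print(shape(Q), shape(K), shape(V))  # (3, 512) (3, 512) (3, 512)
```

Note that X itself is never sliced here; only the projected Q, K, and V are split into heads afterwards.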

Q shape: 3 × 512 (Wq: learned)
K shape: 3 × 512 (Wk: learned)
V shape: 3 × 512 (Wv: learned)

2) Heads work in parallel

In the 4-head slide example, the 512 projected dimensions are split across heads as 128 per head. Each head can learn to focus on different relationships, and the per-head outputs (3 × 128 each) are then concatenated back to 3 × 512.
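A minimal sketch of the split and the concatenation, using the slide numbers (3 tokens, width 512, 4 heads). The function names `split_heads` and `concat_heads` are hypothetical, and the matrix is filled with placeholder values. In a real layer it is each head's attention output that gets concatenated; the round trip below only demonstrates the shape bookkeeping.

```python
def split_heads(m, num_heads):
    """Slice a (seq_len x d_model) matrix into num_heads matrices of width d_model / num_heads."""
    d_model = len(m[0])
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_head = d_model // num_heads
    return [[row[h * d_head:(h + 1) * d_head] for row in m] for h in range(num_heads)]

def concat_heads(heads):
    """Join per-head (seq_len x d_head) matrices back into one (seq_len x d_model) matrix."""
    seq_len = len(heads[0])
    return [[x for head in heads for x in head[t]] for t in range(seq_len)]

seq_len, d_model, num_heads = 3, 512, 4

# Placeholder values standing in for a projected matrix such as Q.
Q = [[float(t * d_model + j) for j in range(d_model)] for t in range(seq_len)]

heads = split_heads(Q, num_heads)
print(len(heads), len(heads[0]), len(heads[0][0]))  # 4 3 128

restored = concat_heads(heads)
print(len(restored), len(restored[0]))  # 3 512
print(restored == Q)                    # True
```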

Why multiple heads? Different views of the same sentence.

One head may look for local grammar; another may look for long-range meaning.

Multi-head attention gives the model several smaller attention spaces instead of one large space. This makes it easier to capture different patterns at once.

3) Tiny attention math

This panel compresses the idea into a toy 4-dimensional example so the score path is visible: QKᵀ → scaled scores → softmax → weighted sum of V.
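The score path can be traced by hand for one query token in one head. The numbers below are invented for illustration and are not the values shown in the panel; the function names `softmax` and `attend` are likewise just for this sketch.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    d_k = len(query)
    raw = [sum(q * k for q, k in zip(query, key)) for key in keys]   # one row of QK^T
    scaled = [s / math.sqrt(d_k) for s in raw]                       # divide by sqrt(d_k)
    weights = softmax(scaled)                                        # weights sum to 1
    out = [sum(w * v[j] for w, v in zip(weights, values))            # weighted sum of V rows
           for j in range(len(values[0]))]
    return raw, scaled, weights, out

# Made-up toy data: 3 tokens, d_k = 4.
q = [1.0, 0.0, 1.0, 0.0]
K = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [1.0, 1.0, 0.0, 0.0]]
V = [[1.0, 0.0, 0.0, 0.0],
     [0.0, 1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]

raw, scaled, weights, out = attend(q, K, V)
print(raw)                              # [2.0, 0.0, 1.0]
print(scaled)                           # [1.0, 0.0, 0.5]
print([round(w, 3) for w in weights])   # [0.506, 0.186, 0.307]
print([round(x, 3) for x in out])       # [0.506, 0.186, 0.307, 0.0]
```

Because the rows of V here are unit vectors, the output simply exposes the attention weights, which makes it easy to see that the key most aligned with the query contributes the most.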

[Interactive panel: for the selected head and query token, shows the attention scores (raw, scaled, and softmax weights) and the final weighted sum.]