A step-by-step mathematical journey from Multi-Head Attention through Grouped Query Attention to Multi-Head Latent Attention, with a deep dive into KV caching and memory optimization
The Evolution Timeline
2017
Multi-Head Attention (MHA)
Original "Attention is All You Need"
2019
Multi-Query Attention (MQA)
Single K,V shared across heads
2023
Grouped Query Attention (GQA)
Balance between MHA and MQA
2024
Multi-Head Latent Attention (MLA)
DeepSeek's compression innovation
Core Problem: As models scale and context lengths grow, the KV cache becomes the primary memory bottleneck during inference.
Foundation: Understanding KV Caching
CRITICAL UNDERSTANDING: KV caching is a universal optimization that applies to ALL attention mechanisms - MHA, MQA, GQA, and MLA. The difference between mechanisms is WHAT gets cached, not WHETHER caching is used.
Why KV Caching? In autoregressive generation, we would otherwise recompute the same K,V values for all previous tokens at every step. Caching eliminates this redundancy through a memory-for-compute trade-off.
Step 1: The Autoregressive Problem
Without KV Cache (Inefficient):
Step 1: "The" → Compute Q,K,V for position 0
Step 2: "The cat" → Recompute Q,K,V for positions 0,1
Step 3: "The cat sat" → Recompute Q,K,V for positions 0,1,2
... Problem: Massive redundant computation!
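Here is a minimal NumPy sketch of the two strategies (single attention head, toy dimensions; the weights, softmax helper, and token embeddings are made-up stand-ins, not a real model). Both paths produce identical outputs - the cached version simply stops re-projecting old tokens:

import numpy as np

d_model, d_head = 8, 8
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((d_model, d_head)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def generate_without_cache(token_embeddings):
    """Recomputes K and V for every previous token at every step."""
    outputs = []
    for t in range(1, len(token_embeddings) + 1):
        x = token_embeddings[:t]                    # all tokens so far
        Q, K, V = x @ W_q, x @ W_k, x @ W_v         # O(t) projections per step
        attn = softmax(Q[-1:] @ K.T / np.sqrt(d_head))
        outputs.append(attn @ V)                    # output for the newest token only
    return np.vstack(outputs)

def generate_with_cache(token_embeddings):
    """Projects each token exactly once; cached K,V rows are reused later."""
    K_cache, V_cache, outputs = [], [], []
    for x_t in token_embeddings:                    # one new token per step
        q = x_t @ W_q
        K_cache.append(x_t @ W_k)                   # the memory-for-compute trade-off
        V_cache.append(x_t @ W_v)
        K, V = np.vstack(K_cache), np.vstack(V_cache)
        attn = softmax(q @ K.T / np.sqrt(d_head))
        outputs.append(attn @ V)
    return np.vstack(outputs)

tokens = rng.standard_normal((3, d_model))          # stand-ins for "The", "cat", "sat"
assert np.allclose(generate_without_cache(tokens), generate_with_cache(tokens))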
Key Insight:
Each head has its own unique K_i and V_i matrices
Total KV cache (per layer) = 2 × seq_len × h × d_head × sizeof(float16)
                           = 2 × seq_len × d_model × 2 bytes
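To make the formula concrete, here is a quick back-of-the-envelope calculation in Python. The 32-layer, 32-head, d_head = 128 config is an illustrative 7B-class setup, not any specific model; the MQA/GQA rows preview how the next sections shrink the cache by reducing the number of cached K/V heads:

def kv_cache_bytes(seq_len, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    """2 (K and V) x seq_len x n_kv_heads x d_head per layer, summed over layers."""
    return 2 * seq_len * n_layers * n_kv_heads * d_head * bytes_per_elem

mha = kv_cache_bytes(seq_len=8192, n_layers=32, n_kv_heads=32, d_head=128)
gqa = kv_cache_bytes(seq_len=8192, n_layers=32, n_kv_heads=8,  d_head=128)
mqa = kv_cache_bytes(seq_len=8192, n_layers=32, n_kv_heads=1,  d_head=128)

print(f"MHA: {mha / 2**30:.2f} GiB")    # 4.00 GiB per 8K-token sequence
print(f"GQA: {gqa / 2**30:.2f} GiB")    # 1.00 GiB (4x smaller with 8 KV heads)
print(f"MQA: {mqa / 2**30:.3f} GiB")    # 0.125 GiB (32x smaller with 1 KV head)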
Multi-Query Attention (MQA) - First Optimization
Radical Idea: What if all heads shared the same K and V?
Mathematical Definition
MQA with h heads:
Q_i = Input × W^Q_i ∈ ℝ^(seq_len × d_head)   # for i = 1..h (unique per head)
K   = Input × W^K   ∈ ℝ^(seq_len × d_head)   # SINGLE shared K
V   = Input × W^V   ∈ ℝ^(seq_len × d_head)   # SINGLE shared V
Head_i = Attention(Q_i, K, V)                # same K,V for all heads
Output = Concat(Head_1, ..., Head_h) × W^O
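A NumPy sketch of the idea (toy dimensions, hypothetical weight names, causal masking included so it matches decoder-style generation). The point to notice is that only one K and one V matrix would need to be cached, no matter how many query heads there are:

import numpy as np

def mqa_attention(x, W_q_heads, W_k, W_v, W_o):
    """Multi-Query Attention: h query heads, ONE shared K and ONE shared V."""
    seq_len = x.shape[0]
    d_head = W_k.shape[1]
    K = x @ W_k                                     # (seq_len, d_head) - shared, cached
    V = x @ W_v                                     # (seq_len, d_head) - shared, cached
    mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)   # causal mask
    heads = []
    for W_q in W_q_heads:                           # one W^Q_i per head
        Q = x @ W_q
        scores = Q @ K.T / np.sqrt(d_head) + mask
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        heads.append(weights @ V)                   # every head reuses the same K,V
    return np.concatenate(heads, axis=-1) @ W_o

# Toy shapes: d_model = 8, h = 4 heads, d_head = 2
rng = np.random.default_rng(1)
d_model, h, d_head = 8, 4, 2
x = rng.standard_normal((5, d_model))
W_q_heads = [rng.standard_normal((d_model, d_head)) for _ in range(h)]
W_k, W_v = rng.standard_normal((d_model, d_head)), rng.standard_normal((d_model, d_head))
W_o = rng.standard_normal((h * d_head, d_model))
out = mqa_attention(x, W_q_heads, W_k, W_v, W_o)    # (5, 8); cache holds only K and V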
Grouped Query Attention (GQA) - The Balanced Approach
Insight: MQA reduces memory but hurts quality. Can we find a middle ground?
Mathematical Definition
GQA with h heads and g groups:
Group size = h / g
Q_i = Input × W^Q_i ∈ ℝ^(seq_len × d_head)   # for i = 1..h (unique per head)
K_j = Input × W^K_j ∈ ℝ^(seq_len × d_head)   # for j = 1..g (one per group)
V_j = Input × W^V_j ∈ ℝ^(seq_len × d_head)   # for j = 1..g (one per group)
Head_i = Attention(Q_i, K_⌈i/(h/g)⌉, V_⌈i/(h/g)⌉)   # each group of h/g query heads shares one K,V pair
Output = Concat(Head_1, ..., Head_h) × W^O
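A sketch under the same assumptions as the MQA example above, the only change being that head i looks up its group's K and V (integer division here plays the role of the ceiling in the formula):

import numpy as np

def gqa_attention(x, W_q_heads, W_k_groups, W_v_groups, W_o):
    """Grouped Query Attention: h query heads share g K,V pairs (h divisible by g)."""
    h, g = len(W_q_heads), len(W_k_groups)
    group_size = h // g                             # query heads per K,V group
    seq_len = x.shape[0]
    d_head = W_k_groups[0].shape[1]
    mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)   # causal mask
    Ks = [x @ W_k for W_k in W_k_groups]            # only g K matrices are cached
    Vs = [x @ W_v for W_v in W_v_groups]            # only g V matrices are cached
    heads = []
    for i, W_q in enumerate(W_q_heads):
        j = i // group_size                         # head i uses group j's K,V
        Q = x @ W_q
        scores = Q @ Ks[j].T / np.sqrt(d_head) + mask
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        heads.append(weights @ Vs[j])
    return np.concatenate(heads, axis=-1) @ W_o

Setting g = h recovers MHA and g = 1 recovers MQA; published GQA models typically land in between, on the order of h/4 to h/8 KV groups.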
Multi-Head Latent Attention (MLA) - DeepSeek's Compression Innovation
Core Idea: Instead of caching K and V themselves, compress each token's input into a small latent vector c^KV, cache only that, and decompress it back into K and V when attention is computed.
Worked example (d_model = 8, latent dim = 3; Token 1, "The", has already been compressed and cached as C¹ᴷⱽ):
Token 2: "cat"
Input₂ = [0.1, 0.7, 0.2, -0.3, 0.6, 0.4, -0.2, 0.5]
C²ᴷⱽ = Input₂ × W_compress = [0.35, 0.47, 0.28]   ← Cache this
Cache = [C¹ᴷⱽ, C²ᴷⱽ] = 6 numbers total, vs. MHA, which would cache 32 numbers for 2 tokens
When Attention is Needed:
K₁ = C¹ᴷⱽ × W_K_decomp = [reconstructed K for "The"]
K₂ = C²ᴷⱽ × W_K_decomp = [reconstructed K for "cat"]
V₁ = C¹ᴷⱽ × W_V_decomp = [reconstructed V for "The"]
V₂ = C²ᴷⱽ × W_V_decomp = [reconstructed V for "cat"]
Then: Normal attention with reconstructed K, V matrices
The Magic: 6 numbers (MLA) vs 32 numbers (MHA) for the same sequence - a 5.3× memory reduction, with the compression learned end-to-end rather than hand-designed!
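A sketch of the cache-the-latent idea with the same toy sizes as the worked example. The names W_compress, W_K_decomp, and W_V_decomp mirror the ones above but hold made-up values; DeepSeek's actual MLA includes further details (a decoupled RoPE key, query compression, and the option to absorb the up-projections into the query/output projections) that are omitted here:

import numpy as np

d_model, d_latent, n_heads, d_head = 8, 3, 4, 2      # n_heads * d_head = d_model
rng = np.random.default_rng(2)
W_compress = rng.standard_normal((d_model, d_latent))            # down-projection
W_K_decomp = rng.standard_normal((d_latent, n_heads * d_head))   # up-projection for K
W_V_decomp = rng.standard_normal((d_latent, n_heads * d_head))   # up-projection for V

latent_cache = []                                    # only d_latent numbers per token

def mla_cache_token(x_t):
    """Compress a new token's input to a latent vector and cache only that."""
    latent_cache.append(x_t @ W_compress)            # 8 numbers -> 3 numbers

def mla_reconstruct_kv():
    """Decompress the cached latents back into per-head K and V when attending."""
    C = np.vstack(latent_cache)                      # (seq_len, d_latent)
    K = (C @ W_K_decomp).reshape(len(latent_cache), n_heads, d_head)
    V = (C @ W_V_decomp).reshape(len(latent_cache), n_heads, d_head)
    return K, V                                      # then: standard multi-head attention

for x_t in rng.standard_normal((2, d_model)):        # stand-ins for "The", "cat"
    mla_cache_token(x_t)
K, V = mla_reconstruct_kv()
cached = sum(c.size for c in latent_cache)           # 2 tokens x 3 = 6 numbers
mha_equivalent = 2 * 2 * d_model                     # 2 tokens x (K + V) x 8 = 32 numbers
print(cached, mha_equivalent)                        # 6 vs 32 -> ~5.3x smaller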