πŸ”„ Evolution of Attention Mechanisms

Step-by-step mathematical journey from Multi-Head Attention through Multi-Query and Grouped Query Attention to Multi-Head Latent Attention, with a deep dive into KV caching and memory optimization

πŸ“ˆ The Evolution Timeline

2017
Multi-Head Attention (MHA)
Original "Attention is All You Need"
2019
Multi-Query Attention (MQA)
Single K,V shared across heads
2023
Grouped Query Attention (GQA)
Balance between MHA and MQA
2024
Multi-Head Latent Attention (MLA)
DeepSeek's compression innovation
🎯 Core Problem: As models scale and context lengths grow, the KV cache becomes the primary memory bottleneck during inference

πŸ—οΈ Foundation: Understanding KV Caching

πŸ”‘ CRITICAL UNDERSTANDING: KV caching is a universal optimization that applies to ALL attention mechanisms - MHA, MQA, GQA, and MLA. The difference between mechanisms is WHAT gets cached, not WHETHER caching is used.
❓ Why KV Caching? In autoregressive generation, we recompute the same K,V values for previous tokens at every step. Caching eliminates this redundancy through a memory-for-compute trade-off.

πŸ” Step 1: The Autoregressive Problem

Without KV Cache (Inefficient):
Step 1: "The" β†’ Compute Q,K,V for position 0
Step 2: "The cat" β†’ Recompute Q,K,V for positions 0,1
Step 3: "The cat sat" β†’ Recompute Q,K,V for positions 0,1,2
...
Problem: Massive redundant computation!
With KV Cache (Efficient):
Step 1: "The" β†’ Compute Qβ‚€,Kβ‚€,Vβ‚€, cache Kβ‚€,Vβ‚€
Step 2: "cat" β†’ Compute Q₁,K₁,V₁, cache K₁,V₁, reuse Kβ‚€,Vβ‚€
Step 3: "sat" β†’ Compute Qβ‚‚,Kβ‚‚,Vβ‚‚, cache Kβ‚‚,Vβ‚‚, reuse Kβ‚€,Vβ‚€,K₁,V₁
...
Trade-off: Use MORE memory to avoid redundant computation!
πŸ”‘ Key Insight: KV caching is a memory-for-compute trade-off:
β€’ ❌ Increases memory usage (store all previous K,V)
β€’ βœ… Reduces computation (avoid recomputing previous tokens)
β€’ ⚑ Dramatically speeds up inference (especially for long sequences)

πŸ“ Step 2: Mathematical Foundation of KV Cache

Standard Attention Formula:
Attention(Q, K, V) = softmax(Q Γ— K^T / √d_k) Γ— V

With KV Cache:
K_cache = [Kβ‚€, K₁, Kβ‚‚, ..., K_{t-1}] # Previous tokens
V_cache = [Vβ‚€, V₁, Vβ‚‚, ..., V_{t-1}] # Previous tokens
K_new = [K_cache, K_t] # Append current token
V_new = [V_cache, V_t] # Append current token
Output = Attention(Q_t, K_new, V_new)
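
This formula translates almost directly into code. Below is a minimal single-head NumPy sketch of a cached decode step; the sizes, random weights, and the decode_step helper are invented for illustration and are not taken from any particular implementation:

import numpy as np

d_model = d_head = 8                       # toy sizes
rng = np.random.default_rng(0)
W_Q = rng.normal(size=(d_model, d_head))   # query projection
W_K = rng.normal(size=(d_model, d_head))   # key projection
W_V = rng.normal(size=(d_model, d_head))   # value projection

def decode_step(x_t, K_cache, V_cache):
    """Process one new token embedding x_t, reusing cached K,V of previous tokens."""
    q_t = x_t @ W_Q                                  # query for the current token only
    K_cache = np.vstack([K_cache, x_t @ W_K])        # append K_t to the cache
    V_cache = np.vstack([V_cache, x_t @ W_V])        # append V_t to the cache
    scores = (q_t @ K_cache.T) / np.sqrt(d_head)     # attend over all cached positions
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                         # softmax
    return weights @ V_cache, K_cache, V_cache       # output + updated cache

K_cache = np.empty((0, d_head))
V_cache = np.empty((0, d_head))
for token in ["The", "cat", "sat"]:
    x_t = rng.normal(size=d_model)                   # stand-in for the token's embedding
    out, K_cache, V_cache = decode_step(x_t, K_cache, V_cache)
print(K_cache.shape, V_cache.shape)                  # (3, 8) (3, 8): one K,V row per token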

πŸ”· Multi-Head Attention (MHA) - The Foundation

πŸ“ Mathematical Definition

MHA with h heads:
Q_i = Input Γ— W^Q_i ∈ ℝ^(seq_len Γ— d_head) # for i = 1..h
K_i = Input Γ— W^K_i ∈ ℝ^(seq_len Γ— d_head) # for i = 1..h
V_i = Input Γ— W^V_i ∈ ℝ^(seq_len Γ— d_head) # for i = 1..h

Head_i = Attention(Q_i, K_i, V_i)
Output = Concat(Head_1, ..., Head_h) Γ— W^O
πŸ”‘ Key Insight:
Each head has its own unique K_i and V_i matrices
Total KV cache (per layer) = 2 × seq_len × h × d_head × sizeof(float16)
                           = 2 × seq_len × d_model × 2 bytes

πŸ”Ά Multi-Query Attention (MQA) - First Optimization

🚨 Radical Idea: What if all heads shared the same K and V?

πŸ“ Mathematical Definition

MQA with h heads:
Q_i = Input Γ— W^Q_i ∈ ℝ^(seq_len Γ— d_head) # for i = 1..h (unique)
K = Input Γ— W^K ∈ ℝ^(seq_len Γ— d_head) # SINGLE shared K
V = Input Γ— W^V ∈ ℝ^(seq_len Γ— d_head) # SINGLE shared V

Head_i = Attention(Q_i, K, V) # Same K,V for all heads
Output = Concat(Head_1, ..., Head_h) Γ— W^O
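
A toy NumPy sketch of the MQA projections (sizes and weights invented for illustration, causal mask omitted): each head keeps its own query projection, but every head attends against the same single K and V:

import numpy as np

d_model, n_heads, d_head = 8, 4, 2          # toy sizes
seq_len = 5
rng = np.random.default_rng(0)
x = rng.normal(size=(seq_len, d_model))     # token embeddings

W_Q = rng.normal(size=(n_heads, d_model, d_head))   # one query projection PER head
W_K = rng.normal(size=(d_model, d_head))            # SINGLE shared key projection
W_V = rng.normal(size=(d_model, d_head))            # SINGLE shared value projection

Q = np.einsum('td,hde->hte', x, W_Q)        # (heads, seq, d_head)
K = x @ W_K                                 # (seq, d_head) -- this is all that gets cached
V = x @ W_V                                 # (seq, d_head)

scores = np.einsum('hte,se->hts', Q, K) / np.sqrt(d_head)   # causal mask omitted for brevity
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)              # softmax over positions
heads = np.einsum('hts,se->hte', weights, V)                # every head reuses the same K,V
print(heads.shape)                                                               # (4, 5, 2)
print("cached floats per token:", 2 * d_head, "vs MHA:", 2 * n_heads * d_head)   # 4 vs 16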

πŸ”΅ Grouped Query Attention (GQA) - The Balanced Approach

πŸ’‘ Insight: MQA reduces memory but hurts quality. Can we find a middle ground?

πŸ“ Mathematical Definition

GQA with h heads, g groups:
Group size = h / g
Q_i = Input Γ— W^Q_i ∈ ℝ^(seq_len Γ— d_head) # for i = 1..h (unique)
K_j = Input Γ— W^K_j ∈ ℝ^(seq_len Γ— d_head) # for j = 1..g (g groups)
V_j = Input Γ— W^V_j ∈ ℝ^(seq_len Γ— d_head) # for j = 1..g (g groups)

Head_i uses K_{⌈i / group_size⌉}, V_{⌈i / group_size⌉}   # heads 1..group_size share group 1, and so on
Output = Concat(Head_1, ..., Head_h) Γ— W^O
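
The only new ingredient is the head-to-group mapping. A tiny sketch, assuming h = 8 heads and g = 2 groups, with 0-indexed heads as in most implementations:

n_heads, n_groups = 8, 2            # assumed values; group_size = 4
group_size = n_heads // n_groups

for head in range(n_heads):         # 0-indexed heads
    print(f"head {head} reads K/V group {head // group_size}")

# Edge cases: n_groups = n_heads recovers MHA, n_groups = 1 recovers MQA.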

πŸ”Ά Multi-Head Latent Attention (MLA) - DeepSeek's Innovation

πŸš€ Revolutionary Idea: Instead of reducing heads, compress the K,V representations themselves!

πŸ—οΈ Complete MLA Architecture

MLA Step-by-Step Process:

1. Joint Compression (NEW):
C^KV = Input @ W_compress ∈ ℝ^(seq_len Γ— d_compress)
where d_compress β‰ͺ n_heads Γ— d_head

2. Cache Compressed Form:
KV_cache = [C^KV] (much smaller!)

3. Query Processing (same as MHA):
Q = Input @ W_Q, reshape to heads

4. Decompression (NEW):
K = C^KV @ W_K_decomp ∈ ℝ^(seq_len Γ— d_model)
V = C^KV @ W_V_decomp ∈ ℝ^(seq_len Γ— d_model)

5. Standard Attention:
Attention(Q, K, V) as usual
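
A simplified NumPy sketch of this flow, using the toy sizes from the worked example below, random weights, and invented helper names (details of the real DeepSeek implementation are omitted): only the compressed Cᴷⱽ is ever cached, and K,V are rebuilt from it when attention runs.

import numpy as np

d_model, n_heads, d_head, d_compress = 8, 4, 2, 3     # toy sizes matching the example below
rng = np.random.default_rng(0)

W_compress = rng.normal(size=(d_model, d_compress))           # joint KV down-projection
W_K_decomp = rng.normal(size=(d_compress, n_heads * d_head))  # up-projection back to per-head K
W_V_decomp = rng.normal(size=(d_compress, n_heads * d_head))  # up-projection back to per-head V

def cache_token(x_t, kv_cache):
    """Steps 1-2: compress the new token and append ONLY the latent to the cache."""
    c_t = x_t @ W_compress                     # (d_compress,)
    return np.vstack([kv_cache, c_t])

def reconstruct_kv(kv_cache):
    """Step 4: decompress the cached latents back into per-head K and V."""
    K = (kv_cache @ W_K_decomp).reshape(-1, n_heads, d_head)   # (seq, heads, d_head)
    V = (kv_cache @ W_V_decomp).reshape(-1, n_heads, d_head)
    return K, V

kv_cache = np.empty((0, d_compress))
for token in ["The", "cat"]:
    kv_cache = cache_token(rng.normal(size=d_model), kv_cache)  # stand-in embeddings

K, V = reconstruct_kv(kv_cache)
print(kv_cache.shape)        # (2, 3): 6 cached numbers in total
print(K.shape, V.shape)      # (2, 4, 2) each, rebuilt on the fly for step 5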

βš–οΈ Complete Comparison: MHA vs GQA vs MLA

⚑ Inference: How It Works in Practice

πŸ”„ Step-by-Step Example: Generating "The cat sat"

Setup: d_model=8, n_heads=4 (so d_head=2), d_compress=3
Memory Comparison: per token, MHA would cache 4×2=8 head-level K/V vectors (16 numbers); MLA caches a single 3-number compressed vector

Token 1: "The"
Input₁ = [0.2, 0.5, -0.1, 0.8, 0.3, -0.4, 0.7, 0.1]
CΒΉα΄·β±½ = Input₁ Γ— W_compress = [0.42, 0.38, 0.31] ← Cache this (3 numbers)
vs MHA: Would cache 16 numbers for 4 heads Γ— 2 (K,V) Γ— 2 dimensions

Token 2: "cat"
Inputβ‚‚ = [0.1, 0.7, 0.2, -0.3, 0.6, 0.4, -0.2, 0.5]
CΒ²α΄·β±½ = Inputβ‚‚ Γ— W_compress = [0.35, 0.47, 0.28] ← Cache this
Cache = [CΒΉα΄·β±½, CΒ²α΄·β±½] = 6 numbers total
vs MHA: Would cache 32 numbers for 2 tokens

When Attention is Needed:
K₁ = CΒΉα΄·β±½ Γ— W_K_decomp = [reconstructed K for "The"]
Kβ‚‚ = CΒ²α΄·β±½ Γ— W_K_decomp = [reconstructed K for "cat"]
V₁ = CΒΉα΄·β±½ Γ— W_V_decomp = [reconstructed V for "The"]
Vβ‚‚ = CΒ²α΄·β±½ Γ— W_V_decomp = [reconstructed V for "cat"]
Then: Normal attention with reconstructed K, V matrices
🎯 The Magic: 6 numbers (MLA) vs 32 numbers (MHA) for the same sequence - a 5.3× memory reduction, with the compression learned during training rather than hand-designed!
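
A quick check of the arithmetic above (same toy sizes):

# d_model=8, n_heads=4 -> d_head=2; d_compress=3; two tokens generated.
n_heads, d_head, d_compress, n_tokens = 4, 2, 3, 2

mla_numbers = n_tokens * d_compress              # 2 tokens x 3 latent dims = 6
mha_numbers = n_tokens * 2 * n_heads * d_head    # 2 tokens x (K+V) x 4 heads x 2 dims = 32
print(mla_numbers, mha_numbers, round(mha_numbers / mla_numbers, 1))   # 6 32 5.3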