The Math: Formal Definitions and Worked Examples
This section formalises Grouped Query Attention and Sliding Window Attention with precise equations and numerical examples you can verify by hand.
Prerequisites
Before reading this section, review:
Part 1: Grouped Query Attention
Definition: Standard Multi-Head Attention (MHA)
For a sequence of length n and d embedding dimensions, split into n_heads heads:
Inputs:
- Query matrix: Q ∈ ℝ^(n × d)
- Key matrix: K ∈ ℝ^(n × d)
- Value matrix: V ∈ ℝ^(n × d)
- d_head = d / n_heads (dimension per head)
Computation for head i:
Q_i = Q W_i^Q ∈ ℝ^(n × d_head) (project to head i)
K_i = K W_i^K ∈ ℝ^(n × d_head)
V_i = V W_i^V ∈ ℝ^(n × d_head)
Scores_i = Q_i K_i^T / √d_head ∈ ℝ^(n × n)
Weights_i = softmax(Scores_i) ∈ ℝ^(n × n)
Output_i = Weights_i V_i ∈ ℝ^(n × d_head)
Concatenate all heads:
Output = Concat(Output_1, ..., Output_n_heads) ∈ ℝ^(n × d)
Output = Output W_O (project back)
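The per-head computation above can be sketched in plain Python. This is a minimal illustration, not an optimised implementation: the function names `softmax`, `attention_head` and `concat_heads` are chosen here for illustration, the inputs are assumed to be already projected (Q_i, K_i, V_i), no mask is applied, and the final W_O projection is left out.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(Q_i, K_i, V_i):
    """Scaled dot-product attention for one head.

    Q_i, K_i, V_i: lists of n rows, each row a list of d_head floats.
    Returns n rows of d_head floats (Output_i).
    """
    scale = math.sqrt(len(Q_i[0]))
    dim = len(V_i[0])
    outputs = []
    for q in Q_i:
        # One row of Scores_i: dot product of q with every key, scaled.
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K_i]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        outputs.append([sum(w * v[j] for w, v in zip(weights, V_i))
                        for j in range(dim)])
    return outputs

def concat_heads(head_outputs):
    """Concatenate per-head outputs along the feature axis."""
    n = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[t]] for t in range(n)]
```

A zero query gives uniform weights, so `attention_head([[0.0, 0.0]], [[1, 0], [0, 1]], [[2, 0], [0, 2]])` returns the mean of the two value rows.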
KV cache during inference (autoregressive generation):
At step t (generating token t), we need K and V from all previous steps 1 to t-1. Cache sizes:
KV_cache = 2 × n_heads × d_head × (t-1) floats per layer
For t = 8192, n_heads = 32, d_head = 128:
KV_cache = 2 × 32 × 128 × 8191 = 67,100,672 floats ≈ 256 MB per layer (float32)
Definition: Grouped Query Attention (GQA)
Instead of n_heads K and V projections, use n_kv_heads where n_kv_heads < n_heads.
Group the query heads: each group of (n_heads / n_kv_heads) query heads shares a single KV head.
Computation for head group g:
group_size = n_heads / n_kv_heads
For query heads in group g: (from index g*group_size to (g+1)*group_size - 1)
Q_g = Q W_{g*group_size}^Q, ..., Q W_{(g+1)*group_size-1}^Q (group_size distinct Q projections)
K_g = K W_g^K (ONE shared K projection)
V_g = V W_g^V (ONE shared V projection)
For each query head i in group g:
Scores_i = Q_i K_g^T / √d_head
Weights_i = softmax(Scores_i)
Output_i = Weights_i V_g
Example: Mistral 7B
- n_heads = 32
- n_kv_heads = 8
- group_size = 32 / 8 = 4
Group 1 (queries 1–4): all attend to KV head 1
Group 2 (queries 5–8): all attend to KV head 2
…
Group 8 (queries 29–32): all attend to KV head 8
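The grouping rule reduces to an integer division. The sketch below uses 0-based indices, matching the g*group_size formula in the definition (the Mistral example above counts from 1); the function names are illustrative assumptions.

```python
def kv_head_for_query(query_head, n_heads, n_kv_heads):
    """Return the 0-based KV head that a 0-based query head reads from."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads
    return query_head // group_size

def query_heads_in_group(g, n_heads, n_kv_heads):
    """Return the 0-based query heads that share KV head g."""
    group_size = n_heads // n_kv_heads
    return list(range(g * group_size, (g + 1) * group_size))

# Mistral 7B shape: 32 query heads, 8 KV heads.
print(kv_head_for_query(4, 32, 8))      # 1  (query 5 -> KV head 2 in 1-based terms)
print(query_heads_in_group(7, 32, 8))   # [28, 29, 30, 31]
```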
KV cache with GQA:
KV_cache_GQA = 2 × n_kv_heads × d_head × (t-1) floats
For t = 8192, n_kv_heads = 8, d_head = 128:
KV_cache_GQA = 2 × 8 × 128 × 8191 = 16,775,168 floats ≈ 64 MB per layer (float32)
Memory reduction factor = 32 / 8 = 4×
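The cache arithmetic is easy to check with a few lines of Python. The helper names below are illustrative, the figures are per layer, and "MB" is read as 2^20 bytes.

```python
def kv_cache_floats(n_kv_heads, d_head, cached_tokens):
    """Floats stored for K and V together in one layer."""
    return 2 * n_kv_heads * d_head * cached_tokens

def to_mib(n_floats, bytes_per_float=4):
    """Convert a float count to MiB (float32 by default)."""
    return n_floats * bytes_per_float / 2**20

mha = kv_cache_floats(n_kv_heads=32, d_head=128, cached_tokens=8191)
gqa = kv_cache_floats(n_kv_heads=8, d_head=128, cached_tokens=8191)

print(mha, round(to_mib(mha), 1))  # 67100672 256.0
print(gqa, round(to_mib(gqa), 1))  # 16775168 64.0
print(mha / gqa)                   # 4.0
```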
Worked Example 1: GQA Computation
Setup:
- Sequence length n = 4 tokens
- Query heads n_heads = 4
- KV heads n_kv_heads = 2
- d_head = 2 (small for manual calculation)
- Group size = 4 / 2 = 2
Inputs (hand-chosen for simplicity):
Embeddings at token 1:
Query projections (4 heads):
Q_1 = [1, 0]
Q_2 = [0, 1]
Q_3 = [1, 1]
Q_4 = [0.5, 0.5]
KV projections (2 heads):
K_1 = [1, 0] (shared by Q_1, Q_2)
K_2 = [0, 1] (shared by Q_3, Q_4)
V_1 = [2, 0]
V_2 = [0, 2]
Keys and values at tokens 1 and 2 (only these two positions are shown for brevity, and no causal mask is applied in this toy example):
K_1 sequence: [[1, 0], [0.5, 0.5]]
K_2 sequence: [[0, 1], [0, 0.5]]
V_1 sequence: [[2, 0], [1, 0]]
V_2 sequence: [[0, 2], [0.5, 1]]
Compute attention for Query head 1 (attends to K_1, V_1):
Scores for token 1:
Q_1 = [1, 0]
K_1 sequence = [[1, 0], [0.5, 0.5]] (keys from tokens 1, 2 as rows)
Attention scores = Q_1 @ K_1^T / √2
= [1, 0] @ [[1, 0.5], [0, 0.5]] / √2 (K_1^T has the keys as columns)
= [1 × 1 + 0 × 0, 1 × 0.5 + 0 × 0.5] / √2
= [1, 0.5] / √2
≈ [0.707, 0.354]
Softmax([0.707, 0.354]):
exp(0.707) ≈ 2.028
exp(0.354) ≈ 1.425
sum ≈ 3.453
weights ≈ [2.028/3.453, 1.425/3.453] ≈ [0.587, 0.413]
Output = weights @ V_1
= [0.587, 0.413] @ [[2, 0], [1, 0]]
= [0.587 × 2 + 0.413 × 1, 0.587 × 0 + 0.413 × 0]
= [1.174 + 0.413, 0]
= [1.587, 0]
Compute attention for Query head 3 (attends to K_2, V_2):
Scores for token 1:
Q_3 = [1, 1]
K_2 sequence = [[0, 1], [0, 0.5]] (keys from tokens 1, 2 as rows)
Attention scores = Q_3 @ K_2^T / √2
= [1, 1] @ [[0, 0], [1, 0.5]] / √2 (K_2^T has the keys as columns)
= [0 + 1, 0 + 0.5] / √2
= [1, 0.5] / √2
≈ [0.707, 0.354]
Softmax: same as above ≈ [0.587, 0.413]
Output = [0.587, 0.413] @ [[0, 2], [0.5, 1]]
= [0 + 0.2065, 1.174 + 0.413]
= [0.2065, 1.587]
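Key observation: Query heads 1 and 3 end up with identical attention weights here but different outputs, because they belong to different groups and read different KV heads (K_1, V_1 versus K_2, V_2). The benefit of GQA shows up inside a group: heads 1 and 2 share K_1 and V_1, yet Q_2 = [0, 1] gives scores [0, 0.5] / √2 ≈ [0, 0.354], weights ≈ [0.412, 0.588], and output ≈ [1.412, 0], which differs from head 1's [1.587, 0]. Heads in a group keep distinct Q projections, so they attend differently while only one K and one V are stored per group. This is the advantage of GQA: expressiveness without full KV cost.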
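The hand calculation can be reproduced with a short script. This is a sketch under the same assumptions as the worked example: only tokens 1 and 2 are present, the query is the one at token 1, and no causal mask is applied. The helper `attend` and the dictionary layout are illustrative choices.

```python
import math

def attend(q, keys, values):
    """Attention for a single query vector over a short list of keys/values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, output

# Query vectors per head (1-based head numbers, as in the example).
Q = {1: [1, 0], 2: [0, 1], 3: [1, 1], 4: [0.5, 0.5]}
# Key and value sequences per KV head (rows are tokens 1 and 2).
K = {1: [[1, 0], [0.5, 0.5]], 2: [[0, 1], [0, 0.5]]}
V = {1: [[2, 0], [1, 0]], 2: [[0, 2], [0.5, 1]]}
GROUP_SIZE = 2

results = {}
for head, q in Q.items():
    kv_head = (head - 1) // GROUP_SIZE + 1  # heads 1-2 -> KV 1, heads 3-4 -> KV 2
    results[head] = attend(q, K[kv_head], V[kv_head])
    weights, output = results[head]
    print(head, kv_head, [round(w, 3) for w in weights], [round(x, 3) for x in output])
```

Comparing the printed rows for heads 1 and 2 shows two heads that read the same KV head but produce different outputs.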
Part 2: Sliding Window Attention
Definition
In Sliding Window Attention, each token attends only to the last W tokens (including itself):
For token at position t:
attend_from = max(0, t - W + 1)
attend_to = t
Attention over: tokens [attend_from, attend_from+1, ..., attend_to]
Number of tokens attended to: min(t + 1, W)
Attention formula (same as standard attention, but masked):
Scores[t, s] = (Q[t] @ K[s]^T) / √d_head if attend_from ≤ s ≤ t
= -∞ otherwise
Weights[t, :] = softmax(Scores[t, :])
Output[t] = Weights[t, :] @ V
The -∞ ensures attention to future or out-of-window tokens is 0 after softmax.
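The window bounds and the masking step can be sketched as follows, using 0-based positions as in the definition. The function names are illustrative; `math.exp(-math.inf)` evaluates to 0.0, which is what zeroes the masked entries.

```python
import math

def window_bounds(t, W):
    """Inclusive (attend_from, attend_to) for the token at position t."""
    return max(0, t - W + 1), t

def sliding_window_mask(n, W):
    """mask[t][s] is True when position t may attend to position s."""
    mask = []
    for t in range(n):
        start, end = window_bounds(t, W)
        mask.append([start <= s <= end for s in range(n)])
    return mask

def masked_softmax(scores_row, mask_row):
    """Softmax over one row, with disallowed positions set to -inf first."""
    masked = [s if keep else -math.inf for s, keep in zip(scores_row, mask_row)]
    peak = max(masked)  # finite, because every token may attend to itself
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

for row in sliding_window_mask(6, 3):
    print("".join("x" if keep else "." for keep in row))
```

The printout is a band of width W = 3 below and including the diagonal; the first rows are shorter because fewer than W tokens exist yet.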
Worked Example 2: Sliding Window Attention
Setup:
- Sequence: 6 tokens [1, 2, 3, 4, 5, 6]
- Window size W = 3
- d_head = 1 (scalar) for simplicity
Inputs:
Queries: Q = [1, 2, 1, 3, 2, 4]
Keys: K = [1, 2, 1, 3, 2, 4] (same as Q for this example)
Values: V = [10, 20, 10, 30, 20, 40]
Compute attention for token 4 (position index 3):
Attend from: max(0, 3 - 3 + 1) = 1
Attend to: 3
Attend to tokens at indices [1, 2, 3] (values [2, 1, 3] for K, [20, 10, 30] for V)
Scores[3] = Q[3] @ [K[1], K[2], K[3]]^T / √1
= 3 @ [2, 1, 3]^T
= [6, 3, 9]
Softmax([6, 3, 9]):
exp(6) ≈ 403.43
exp(3) ≈ 20.09
exp(9) ≈ 8103.08
sum ≈ 8526.60
weights ≈ [403.43/8526.60, 20.09/8526.60, 8103.08/8526.60]
≈ [0.0473, 0.0024, 0.9503]
Output[3] = [0.0473, 0.0024, 0.9503] @ [20, 10, 30]
= 0.0473 × 20 + 0.0024 × 10 + 0.9503 × 30
= 0.946 + 0.024 + 28.509
= 29.479
Token 4 focuses almost entirely on itself (weight 0.9503). With scalar queries and keys the score is just Q[t] × K[s], so for a positive query the largest key in the window wins, and here that is token 4's own key (3).
Compute attention for token 6 (position index 5):
Attend from: max(0, 5 - 3 + 1) = 3
Attend to: 5
Attend to tokens at indices [3, 4, 5] (K = [3, 2, 4], V = [30, 20, 40])
Scores[5] = 4 @ [3, 2, 4]^T
= [12, 8, 16]
Softmax([12, 8, 16]):
exp(12) ≈ 162,754
exp(8) ≈ 2,981
exp(16) ≈ 8,886,110
sum ≈ 9,051,845
weights ≈ [0.0180, 0.0003, 0.9817]
Output[5] = [0.0180, 0.0003, 0.9817] @ [30, 20, 40]
= 0.0180 × 30 + 0.0003 × 20 + 0.9817 × 40
= 0.540 + 0.006 + 39.268
= 39.814
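Again, nearly all of the weight lands on the token itself, because it has the largest key in the window. Tokens 1–3 lie outside the window and receive exactly zero weight.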
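The two hand calculations can be checked with a short script. It is a sketch that hard-codes the example's scalar Q, K and V and uses 0-based positions; the function name `swa_output` is an illustrative choice.

```python
import math

Q = [1, 2, 1, 3, 2, 4]
K = [1, 2, 1, 3, 2, 4]
V = [10, 20, 10, 30, 20, 40]
W = 3

def swa_output(t):
    """Return (window positions, weights, output) for the token at position t."""
    start = max(0, t - W + 1)
    positions = list(range(start, t + 1))
    # d_head = 1, so the scaling factor sqrt(d_head) is 1.
    scores = [Q[t] * K[s] for s in positions]
    peak = max(scores)
    exps = [math.exp(x - peak) for x in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    output = sum(w * V[s] for w, s in zip(weights, positions))
    return positions, weights, output

for t in (3, 5):
    positions, weights, output = swa_output(t)
    print(t, positions, [round(w, 4) for w in weights], round(output, 3))
```

Running `swa_output(4)` is also instructive: the token at index 4 has query 2 and sees keys [1, 3, 2], so most of its weight goes to index 3 rather than to itself.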
Part 3: Complexity Analysis
Attention Complexity
Standard attention (all pairs):
Number of (query, key) pairs: n × n = n²
Dimension per pair: d_head
Total FLOPs per layer: O(n² × d_head)
For n = 8,192, d_head = 128: 8,192² × 128 = 8,589,934,592 FLOPs
Sliding Window Attention:
Number of (query, key) pairs: n × W
(each query attends to W tokens)
Total FLOPs per layer: O(n × W × d_head)
For n = 8,192, W = 4,096, d_head = 128:
8,192 × 4,096 × 128 = 4,294,967,296 FLOPs
Reduction: 8,589,934,592 / 4,294,967,296 = 2× per layer
Across 32 layers the ratio stays 2×: every layer's cost is halved, so the total is halved too (per-layer savings do not multiply).
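The pair counts above treat every query as attending to n (or W) keys. A short script makes the counts concrete. It also shows the exact counts under a causal mask, where early tokens have fewer than W predecessors, which is an addition to the simplified n × n and n × W figures used in this section. Function names are illustrative.

```python
def full_pairs(n):
    """(query, key) pairs with no mask: n x n."""
    return n * n

def window_pairs_upper(n, W):
    """Upper bound used in this section: every query sees W keys."""
    return n * W

def causal_full_pairs(n):
    """Exact count with a causal mask: token t sees t + 1 keys."""
    return n * (n + 1) // 2

def causal_window_pairs(n, W):
    """Exact count with a causal sliding window: token t sees min(t + 1, W) keys."""
    return sum(min(t + 1, W) for t in range(n))

n, W, d_head, n_layers = 8192, 4096, 128, 32

per_layer_full = full_pairs(n) * d_head
per_layer_swa = window_pairs_upper(n, W) * d_head
print(per_layer_full, per_layer_swa, per_layer_full / per_layer_swa)
# The ratio is the same when both totals are multiplied by the layer count.
print((n_layers * per_layer_full) / (n_layers * per_layer_swa))
print(causal_full_pairs(n), causal_window_pairs(n, W))
```

Under the causal counts the saving at n = 8,192 is smaller than 2×, because a causal model already skips half of the full n × n matrix.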
Memory Complexity
Standard attention (full sequence):
KV cache per layer: 2 × n_heads × d_head × n
For n_heads = 32, d_head = 128, n = 8,192:
KV_cache = 2 × 32 × 128 × 8,192 = 67,108,864 floats ≈ 256 MB
With GQA (8 KV heads):
KV_cache_GQA = 2 × 8 × 128 × 8,192 = 16,777,216 floats ≈ 64 MB
Memory reduction: 4×
With GQA + SWA (store only W tokens):
KV_cache_GQA_SWA = 2 × 8 × 128 × 4,096 = 8,388,608 floats ≈ 32 MB
Total reduction: 256 / 32 = 8×
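One possible way to "store only W tokens" is a fixed-size buffer in which position p overwrites slot p mod W. The class below is a hypothetical sketch of that idea for a single KV head; the class name, method names and data layout are assumptions made for illustration, not a description of any particular library.

```python
class RollingKVCache:
    """Fixed-size cache that keeps the K and V of the last `window` tokens."""

    def __init__(self, window):
        self.window = window
        self.keys = [None] * window
        self.values = [None] * window
        self.n_seen = 0  # total tokens appended so far

    def append(self, key, value):
        """Store the newest token, overwriting the oldest once the cache is full."""
        slot = self.n_seen % self.window
        self.keys[slot] = key
        self.values[slot] = value
        self.n_seen += 1

    def contents(self):
        """Return (position, key, value) triples for cached tokens, oldest first."""
        start = max(0, self.n_seen - self.window)
        return [(p, self.keys[p % self.window], self.values[p % self.window])
                for p in range(start, self.n_seen)]

cache = RollingKVCache(window=3)
for position in range(5):
    cache.append(key=[float(position)], value=[10.0 * position])
print([p for p, _, _ in cache.contents()])  # [2, 3, 4]
```

Memory stays constant at W entries however long the sequence grows, which is where the factor n / W = 2 in the 8× figure comes from.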
Receptive Field with Sliding Windows
In a multilayer network, each layer has a local window of size W. But information propagates:
Layer 1: Each token sees up to W previous tokens
Layer 2: Each token sees layer-1 outputs, each of which already summarises up to W tokens
Receptive field ≈ 2 × W tokens
...
Layer k: Receptive field ≈ k × W tokens
For Mistral 7B: k = 32, W = 4,096
Effective receptive field ≈ 131,072 tokens
This means even with a 4,096 sliding window, the model can implicitly
attend to information from 131K tokens in the past via stacked layers.
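This is why SWA does not impose a hard limit of W tokens on context: information from beyond the window can still reach a token, but only indirectly, by propagating through intermediate positions layer by layer. The k × W figure is a theoretical upper bound on reach, not a guarantee that distant information is used effectively.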
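The k × W estimate can be checked by walking the window backwards one layer at a time. With the window defined as above (W tokens including the current one), each layer reaches W − 1 positions further back, so the exact span is k × (W − 1) + 1, which k × W approximates. The function names below are illustrative.

```python
def earliest_reachable(t, n_layers, W):
    """Earliest position whose information can reach position t after n_layers."""
    position = t
    for _ in range(n_layers):
        # Each layer lets a token read from up to W - 1 positions before it.
        position = max(0, position - (W - 1))
    return position

def receptive_field(n_layers, W):
    """Number of positions (including the current one) that can influence a token."""
    return n_layers * (W - 1) + 1

print(receptive_field(1, 4096))    # 4096
print(receptive_field(32, 4096))   # 131041, close to 32 * 4096 = 131072
print(earliest_reachable(200_000, 32, 4096))  # 68960
```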
Summary of Key Numbers (Mistral 7B)
| Metric | Value |
|---|---|
| n_heads | 32 |
| n_kv_heads | 8 |
| GQA reduction factor | 32/8 = 4× |
| d_head | 128 |
| W (window size) | 4,096 |
| Number of layers | 32 |
| Receptive field | 32 × 4,096 = 131,072 tokens |
| Training sequence length | Up to 32,768 tokens |
| KV cache per layer, 8K sequence (float32) | 64 MB (vs 256 MB for MHA) |
| Attention FLOPs per layer | ~4.3B (vs ~8.6B for full attention) |
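Across all 32 layers, the float32 KV cache for an 8K sequence drops from 32 × 256 MB = 8 GB (MHA) to 32 × 64 MB = 2 GB (GQA), or 1 GB with the 4,096-token window. Savings of this size are what make inference feasible on consumer hardware.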