Section 04

The Math: Formal Definitions and Worked Examples

Mistral 7B 2023

This section formalises Grouped Query Attention and Sliding Window Attention with precise equations and numerical examples you can verify by hand.

Part 1: Grouped Query Attention

Definition: Standard Multi-Head Attention (MHA)

For a sequence of length n with embedding dimension d, split across n_heads attention heads:

Inputs:

  • Query matrix: Q ∈ ℝ^(n × d)
  • Key matrix: K ∈ ℝ^(n × d)
  • Value matrix: V ∈ ℝ^(n × d)
  • d_head = d / n_heads (dimension per head)

Computation for head i:

Q_i = Q W_i^Q  ∈ ℝ^(n × d_head)     (project to head i)
K_i = K W_i^K  ∈ ℝ^(n × d_head)
V_i = V W_i^V  ∈ ℝ^(n × d_head)

Scores_i = Q_i K_i^T / √d_head  ∈ ℝ^(n × n)

Weights_i = softmax(Scores_i)  ∈ ℝ^(n × n)

Output_i = Weights_i V_i  ∈ ℝ^(n × d_head)

Concatenate all heads:

Output = Concat(Output_1, ..., Output_n_heads)  ∈ ℝ^(n × d)
Output = Output W_O  (project back)
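
The per-head computation above can be sketched in plain Python. This is a minimal illustration rather than a reference implementation: the function names (softmax, attention_head, concat_heads) are hypothetical, the inputs are assumed to be already projected by W_i^Q, W_i^K and W_i^V, no causal mask is applied, and the final W_O projection is left out.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(Q_i, K_i, V_i):
    """Scaled dot-product attention for one head.

    Q_i, K_i, V_i are lists of n rows, each of length d_head
    (assumed to be already projected for head i).
    """
    d_head = len(Q_i[0])
    scale = math.sqrt(d_head)
    outputs = []
    for q in Q_i:
        # Scores_i row: dot product of this query with every key, scaled
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K_i]
        weights = softmax(scores)
        # Output_i row: weighted sum of the value rows
        out = [sum(w * v[j] for w, v in zip(weights, V_i))
               for j in range(len(V_i[0]))]
        outputs.append(out)
    return outputs

def concat_heads(head_outputs):
    """Concatenate per-head outputs row by row: n x (n_heads * d_head)."""
    n = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[t]] for t in range(n)]
```

Calling attention_head once per head and passing the results to concat_heads reproduces the Concat step; multiplying by W_O would follow.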

KV cache during inference (autoregressive generation):

At step t (generating token t), we need K and V from all previous steps 1 to t-1. Cache size per layer:

KV_cache = 2 × n_heads × d_head × (t-1)  floats

For t = 8192, n_heads = 32, d_head = 128:
KV_cache = 2 × 32 × 128 × 8191 = 67,100,672 floats ≈ 256 MB (float32)
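
The cache-size formula is easy to check with a few lines of Python. The helper names below (kv_cache_floats, to_mib) are hypothetical, and the sketch assumes 4 bytes per float32 value and counts a single layer only.

```python
def kv_cache_floats(n_kv_heads, d_head, cached_tokens):
    """Floats stored for K and V in one layer: 2 * heads * d_head * tokens.

    For standard MHA, n_kv_heads equals n_heads.
    """
    return 2 * n_kv_heads * d_head * cached_tokens

def to_mib(n_floats, bytes_per_float=4):
    """Convert a float count to mebibytes (float32 by default)."""
    return n_floats * bytes_per_float / 2**20

floats = kv_cache_floats(n_kv_heads=32, d_head=128, cached_tokens=8192 - 1)
print(floats, round(to_mib(floats), 2))
```

Passing a smaller n_kv_heads to the same function gives the cache size for a grouped configuration.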

Definition: Grouped Query Attention (GQA)

Instead of n_heads K and V projections, use n_kv_heads where n_kv_heads < n_heads.

Group the query heads: each group of (n_heads / n_kv_heads) query heads shares a single KV head.

Computation for head group g:

group_size = n_heads / n_kv_heads

For query heads in group g: (from index g*group_size to (g+1)*group_size - 1)
  Q_g = Q W_{g*group_size}^Q, ..., Q W_{(g+1)*group_size-1}^Q  (group_size distinct Q projections)
  
  K_g = K W_g^K  (ONE shared K projection)
  V_g = V W_g^V  (ONE shared V projection)
  
  For each query head i in group g:
    Scores_i = Q_i K_g^T / √d_head
    Weights_i = softmax(Scores_i)
    Output_i = Weights_i V_g

Example: Mistral 7B

  • n_heads = 32
  • n_kv_heads = 8
  • group_size = 32 / 8 = 4

Group 1 (queries 1–4) all attend to KV head 1
Group 2 (queries 5–8) all attend to KV head 2
…
Group 8 (queries 29–32) all attend to KV head 8
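
The grouping rule reduces to an integer division. The sketch below uses 0-based indices, matching the g*group_size formula rather than the 1-based numbering in the example; the function name kv_head_for_query_head is a hypothetical helper.

```python
def kv_head_for_query_head(query_head, n_heads, n_kv_heads):
    """0-based index of the KV head used by a 0-based query head."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads
    return query_head // group_size

# Build the groups for the Mistral 7B head counts (32 query heads, 8 KV heads)
groups = {}
for q in range(32):
    groups.setdefault(kv_head_for_query_head(q, 32, 8), []).append(q)

for kv_head, query_heads in groups.items():
    print(kv_head, query_heads)
```

Setting n_kv_heads equal to n_heads recovers standard MHA (every query head has its own KV head), and n_kv_heads = 1 gives a single KV head shared by all query heads.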

KV cache with GQA:

KV_cache_GQA = 2 × n_kv_heads × d_head × (t-1)  floats

For t = 8192, n_kv_heads = 8, d_head = 128:
KV_cache_GQA = 2 × 8 × 128 × 8191 = 16,775,168 floats ≈ 64 MB (float32)

Memory reduction factor = 32 / 8 = 4×

Worked Example 1: GQA Computation

Setup:

  • Sequence length n = 4 tokens
  • Query heads n_heads = 4
  • KV heads n_kv_heads = 2
  • d_head = 2 (small for manual calculation)
  • Group size = 4 / 2 = 2

Inputs (hand-chosen for simplicity):

Per-head projections at token 1:
  Query projections (4 heads):
    Q_1 = [1, 0]    Q_2 = [0, 1]    Q_3 = [1, 1]    Q_4 = [0.5, 0.5]
  
  KV projections (2 heads):
    K_1 = [1, 0]    (shared by Q_1, Q_2)
    K_2 = [0, 1]    (shared by Q_3, Q_4)
    
    V_1 = [2, 0]
    V_2 = [0, 2]

Keys and values for tokens 1 and 2 (only these two positions are shown, for brevity):
  K_1 sequence: [[1, 0], [0.5, 0.5]]
  K_2 sequence: [[0, 1], [0, 0.5]]
  V_1 sequence: [[2, 0], [1, 0]]
  V_2 sequence: [[0, 2], [0.5, 1]]

Compute attention for Query head 1 (attends to K_1, V_1):

Scores for the token-1 query against the keys of tokens 1 and 2 (this toy example applies no causal mask):

Q_1 = [1, 0]
K_1^T = [[1, 0.5], [0, 0.5]] (keys from tokens 1, 2 as columns)

Attention scores = Q_1 @ K_1^T / √2
                 = [1, 0] @ [[1, 0], [0.5, 0.5]]^T / √2
                 = [1, 0] @ [[1, 0.5], [0, 0.5]] / √2
                 = [1, 0.5] / √2
                 ≈ [0.707, 0.354]

Softmax([0.707, 0.354]):
  exp(0.707) ≈ 2.028
  exp(0.354) ≈ 1.425
  sum ≈ 3.453
  
  weights ≈ [2.028/3.453, 1.425/3.453] ≈ [0.587, 0.413]

Output = weights @ V_1
       = [0.587, 0.413] @ [[2, 0], [1, 0]]
       = [0.587 × 2 + 0.413 × 1, 0.587 × 0 + 0.413 × 0]
       = [1.174 + 0.413, 0]
       = [1.587, 0]

Compute attention for Query head 3 (attends to K_2, V_2):

Scores for token 1:

Q_3 = [1, 1]
K_2^T = [[0, 0], [1, 0.5]] (keys from tokens 1, 2 as columns)

Attention scores = [1, 1] @ [[0, 1], [0, 0.5]]^T / √2
                 = [1, 1] @ [[0, 0], [1, 0.5]] / √2
                 = [0 + 1, 0 + 0.5] / √2
                 = [1, 0.5] / √2
                 ≈ [0.707, 0.354]

Softmax: same as above ≈ [0.587, 0.413]

Output = [0.587, 0.413] @ [[0, 2], [0.5, 1]]
       = [0 + 0.2065, 1.174 + 0.413]
       = [0.2065, 1.587]

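Key observation: Query heads 1 and 3 belong to different groups, so they read different KV heads (K_1, V_1 versus K_2, V_2). Here they happen to produce the same attention weights, and their outputs differ only because V_1 and V_2 differ. The GQA property shows up inside a group: query head 2 shares K_1 and V_1 with head 1, but Q_2 = [0, 1] gives scores [0, 0.5] / √2 ≈ [0, 0.354], weights ≈ [0.413, 0.587], and output ≈ [1.413, 0] instead of [1.587, 0]. Heads that share a KV head still attend differently because each keeps its own Q projection. This is the advantage of GQA: expressiveness without full KV cost.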
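
The hand calculation can be reproduced with a short script. This is a sketch under the same assumptions as the example: only tokens 1 and 2 are included, no causal mask is applied, and heads are numbered from 1 as in the text. The helper names (softmax, attend) are hypothetical.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def attend(q, keys, values):
    """One query vector against a list of key/value rows."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    weights = softmax(scores)
    out = [sum(w * v[j] for w, v in zip(weights, values))
           for j in range(len(values[0]))]
    return weights, out

# Token-1 queries for the four query heads (1-based, as in the text)
Q = {1: [1.0, 0.0], 2: [0.0, 1.0], 3: [1.0, 1.0], 4: [0.5, 0.5]}
# Key and value rows for tokens 1 and 2, per KV head
K = {1: [[1.0, 0.0], [0.5, 0.5]], 2: [[0.0, 1.0], [0.0, 0.5]]}
V = {1: [[2.0, 0.0], [1.0, 0.0]], 2: [[0.0, 2.0], [0.5, 1.0]]}
group_size = 2

results = {}
for head, q in Q.items():
    kv = (head - 1) // group_size + 1  # heads 1-2 -> KV head 1, heads 3-4 -> KV head 2
    results[head] = attend(q, K[kv], V[kv])
    weights, out = results[head]
    print(head, kv, [round(w, 3) for w in weights], [round(x, 3) for x in out])
```

Comparing the rows for heads 1 and 2 shows two different outputs computed from the same cached K_1 and V_1.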


Part 2: Sliding Window Attention

Definition

In Sliding Window Attention, each token attends only to the last W tokens (including itself):

For token at position t:

attend_from = max(0, t - W + 1)
attend_to = t

Attention over: tokens [attend_from, attend_from+1, ..., attend_to]
Number of tokens attended to: min(t + 1, W)

Attention formula (same as standard attention, but masked):

Scores[t, s] = (Q[t] @ K[s]^T) / √d_head    if attend_from ≤ s ≤ t
             = -∞                            otherwise

Weights[t, :] = softmax(Scores[t, :])

Output[t] = Weights[t, :] @ V

The -∞ ensures attention to future or out-of-window tokens is 0 after softmax.
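
The masking rule can be written out directly. The sketch below builds the boolean mask from attend_from and t (0-based positions) and replaces disallowed scores with -inf; the function names sliding_window_mask and masked_scores are hypothetical.

```python
def sliding_window_mask(n, W):
    """mask[t][s] is True when query t may attend to key s (0-based)."""
    mask = []
    for t in range(n):
        attend_from = max(0, t - W + 1)
        mask.append([attend_from <= s <= t for s in range(n)])
    return mask

def masked_scores(scores, mask):
    """Replace scores outside the window (or in the future) with -inf."""
    return [[x if ok else float("-inf") for x, ok in zip(row, mrow)]
            for row, mrow in zip(scores, mask)]

# Visualise the band for 6 tokens and W = 3: 'x' = attended, '.' = masked
for row in sliding_window_mask(6, 3):
    print("".join("x" if ok else "." for ok in row))
```

Each row contains min(t + 1, W) attended positions, so the mask is a diagonal band of width W rather than a full lower triangle.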

Worked Example 2: Sliding Window Attention

Setup:

  • Sequence: 6 tokens [1, 2, 3, 4, 5, 6]
  • Window size W = 3
  • d_head = 1 (scalar) for simplicity

Inputs:

Queries:   Q = [1, 2, 1, 3, 2, 4]
Keys:      K = [1, 2, 1, 3, 2, 4]  (same as Q for this example)
Values:    V = [10, 20, 10, 30, 20, 40]

Compute attention for token 4 (position index 3):

Attend from: max(0, 3 - 3 + 1) = 1
Attend to: 3

Attend to tokens at indices [1, 2, 3] (values [2, 1, 3] for K, [20, 10, 30] for V)

Scores[3] = Q[3] @ [K[1], K[2], K[3]]^T / √1
          = 3 @ [2, 1, 3]^T
          = [6, 3, 9]

Softmax([6, 3, 9]):
  exp(6) ≈ 403.43
  exp(3) ≈ 20.09
  exp(9) ≈ 8103.08
  sum ≈ 8526.60
  
  weights ≈ [403.43/8526.60, 20.09/8526.60, 8103.08/8526.60]
          ≈ [0.0473, 0.0024, 0.9503]

Output[3] = [0.0473, 0.0024, 0.9503] @ [20, 10, 30]
          = 0.0473 × 20 + 0.0024 × 10 + 0.9503 × 30
          = 0.946 + 0.024 + 28.509
          = 29.479

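Token 4 focuses almost entirely on itself (weight 0.9503). This makes sense: Q and K are identical and positive in this example, so the score Q[3] × K[s] is largest for the largest key in the window, which is token 4's own key (3).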

Compute attention for token 6 (position index 5):

Attend from: max(0, 5 - 3 + 1) = 3
Attend to: 5

Attend to tokens at indices [3, 4, 5] (K = [3, 2, 4], V = [30, 20, 40])

Scores[5] = 4 @ [3, 2, 4]^T
          = [12, 8, 16]

Softmax([12, 8, 16]):
  exp(12) ≈ 162,754
  exp(8) ≈ 2,981
  exp(16) ≈ 8,886,110
  sum ≈ 9,051,845
  
  weights ≈ [0.0180, 0.0003, 0.9817]

Output[5] = [0.0180, 0.0003, 0.9817] @ [30, 20, 40]
          = 0.0180 × 30 + 0.0003 × 20 + 0.9817 × 40
          = 0.540 + 0.006 + 39.268
          = 39.814

Again, the token attends mostly to itself and its recent neighbours.
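
Both outputs in this example can be checked numerically. The sketch below assumes the scalar setup from the example (d_head = 1, so the scaling factor is 1) and 0-based positions; swa_output is a hypothetical helper name.

```python
import math

def swa_output(t, Q, K, V, W):
    """Scalar sliding-window attention output for 0-based position t."""
    start = max(0, t - W + 1)
    window = range(start, t + 1)
    scores = [Q[t] * K[s] for s in window]  # sqrt(d_head) = 1
    m = max(scores)                         # subtract max for stability
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    out = sum(w * V[s] for w, s in zip(weights, window))
    return weights, out

Q = [1, 2, 1, 3, 2, 4]
K = [1, 2, 1, 3, 2, 4]
V = [10, 20, 10, 30, 20, 40]

for t in (3, 5):
    weights, out = swa_output(t, Q, K, V, W=3)
    print(t, [round(w, 4) for w in weights], round(out, 2))
```

Subtracting the maximum score before exponentiating leaves the weights unchanged but avoids very large intermediate values such as exp(16).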


Part 3: Complexity Analysis

Attention Complexity

Standard attention (all pairs):

Number of (query, key) pairs: n × n = n²
Dimension per pair: d_head
Total FLOPs per layer: O(n² × d_head)

For n = 8,192, d_head = 128: 8,192² × 128 = 8,589,934,592 FLOPs

Sliding Window Attention:

Number of (query, key) pairs: n × W
(each query attends to at most W tokens)

Total FLOPs per layer: O(n × W × d_head)

For n = 8,192, W = 4,096, d_head = 128:
8,192 × 4,096 × 128 = 4,294,967,296 FLOPs

Reduction: 8,589,934,592 / 4,294,967,296 ≈ 2× per layer

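Every one of the 32 layers gets the same 2× saving, so total attention compute also falls by about 2× at this sequence length; per-layer savings do not compound across layers. The ratio n / W grows with sequence length: at n = 32,768 it reaches 8×.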
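
The two counts can be compared for several sequence lengths. The sketch uses the same upper-bound arithmetic as above (pairs × d_head, which corresponds to one head's score matrix) and caps the window at n when the sequence is shorter than W; attention_pair_ops is a hypothetical helper name.

```python
def attention_pair_ops(n, d_head, W=None):
    """Upper-bound operation count for the score matrix: pairs * d_head.

    W=None means full attention; otherwise each query sees at most W keys.
    """
    keys_per_query = n if W is None else min(n, W)
    return n * keys_per_query * d_head

for n in (4096, 8192, 16384, 32768):
    full = attention_pair_ops(n, 128)
    swa = attention_pair_ops(n, 128, W=4096)
    print(n, full, swa, full / swa)
```

The ratio is n / W once n exceeds W, so the saving is linear in sequence length, while the sliding-window cost itself grows only linearly in n.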

Memory Complexity

Standard attention (full sequence):

KV cache per layer: 2 × n_heads × d_head × n
For n_heads = 32, d_head = 128, n = 8,192:
KV_cache = 2 × 32 × 128 × 8,192 = 67,108,864 floats ≈ 256 MB

With GQA (8 KV heads):

KV_cache_GQA = 2 × 8 × 128 × 8,192 = 16,777,216 floats ≈ 64 MB
Memory reduction: 4×

With GQA + SWA (store only W tokens):

KV_cache_GQA_SWA = 2 × 8 × 128 × 4,096 = 8,388,608 floats ≈ 32 MB
Total reduction: 256 / 32 = 8×
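
One way to store only W tokens is a fixed-size buffer in which position p is written to slot p mod W, overwriting the oldest entry. The class below is an illustrative sketch for a single KV head; the name RollingKVCache and its methods are hypothetical and do not come from any particular library.

```python
class RollingKVCache:
    """Keeps the K and V vectors of the last W positions for one KV head."""

    def __init__(self, window):
        self.window = window
        self.keys = [None] * window
        self.values = [None] * window
        self.count = 0  # total number of tokens appended so far

    def append(self, k, v):
        slot = self.count % self.window  # overwrites the oldest entry once full
        self.keys[slot] = k
        self.values[slot] = v
        self.count += 1

    def window_contents(self):
        """Return (positions, keys, values) for cached tokens, oldest first."""
        start = max(0, self.count - self.window)
        positions = list(range(start, self.count))
        keys = [self.keys[p % self.window] for p in positions]
        values = [self.values[p % self.window] for p in positions]
        return positions, keys, values

cache = RollingKVCache(window=3)
for pos in range(6):
    cache.append(k=[float(pos)], v=[10.0 * pos])
positions, keys, values = cache.window_contents()
print(positions, keys, values)
```

The buffer never grows beyond W entries, which is where the factor n / W in the memory calculation comes from.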

Receptive Field with Sliding Windows

In a multilayer network, each layer has a local window of size W, but information propagates across layers:

Layer 1: Each token sees up to W previous tokens
Layer 2: Each token attends to W layer-1 outputs, each of which
         already summarises up to W tokens
         Receptive field ≈ 2 × W tokens
...
Layer k: Receptive field ≈ k × W tokens

For Mistral 7B: k = 32, W = 4,096
Effective receptive field ≈ 131,072 tokens

This means even with a 4,096 sliding window, the model can implicitly
attend to information from 131K tokens in the past via stacked layers.

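This is why SWA does not cap the usable context at W tokens: information from earlier positions can still reach a token indirectly, up to roughly k × W positions back, as it propagates up through the layers. What changes is how the model processes long sequences: locally within each layer, then across layers.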
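
The propagation argument can be simulated on a small sequence. The sketch tracks which input positions can influence a given position after a number of sliding-window layers. Because the window includes the token itself, each layer extends the reach by W - 1 positions, so the exact count is k × (W - 1) + 1, which the k × W figure approximates. The function names are hypothetical.

```python
def reachable_positions(t, num_layers, W):
    """Input positions that can influence position t after num_layers layers."""
    frontier = {t}
    for _ in range(num_layers):
        # Every position reached so far can itself look back W - 1 positions
        frontier = {s for p in frontier
                    for s in range(max(0, p - W + 1), p + 1)}
    return frontier

def receptive_field_bound(num_layers, W):
    """Closed form for the same count when the sequence is long enough."""
    return num_layers * (W - 1) + 1

print(sorted(reachable_positions(t=10, num_layers=2, W=3)))
print(receptive_field_bound(32, 4096))
```

For 32 layers and W = 4,096 the closed form gives 131,041 positions, within 0.03% of the 131,072 estimate.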


Summary of Key Numbers (Mistral 7B)

Metric                                       Value
n_heads                                      32
n_kv_heads                                   8
GQA reduction factor                         32/8 = 4×
d_head                                       128
W (window size)                              4,096
Number of layers                             32
Receptive field                              32 × 4,096 = 131,072 tokens
Training sequence length                     Up to 32,768 tokens
KV cache per layer, 8K sequence (float32)    64 MB (vs 256 MB for MHA)
Attention FLOPs per layer (n = 8,192)        ~4.3B (vs ~8.6B for full attention)

These numbers make inference feasible on consumer hardware.