Section 05

Worked Example: GQA in Action

Mistral 7B 2023

This section walks through a complete, tiny example of Grouped Query Attention (GQA) step by step. You can verify every number by hand.

Setup

Parameters:

  • Sequence length: n = 3 tokens
  • Query heads: n_heads = 4
  • KV heads: n_kv_heads = 2
  • Dimension per head: d_head = 2
  • Group size: 4 / 2 = 2 (queries 1–2 share KV head 1; queries 3–4 share KV head 2)
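
The grouping rule in the last bullet amounts to an integer division. A minimal sketch, assuming 0-based head indices internally (the text counts from 1) and a hypothetical helper name `kv_head_for`:

```python
n_heads = 4      # query heads
n_kv_heads = 2   # shared key/value heads
group_size = n_heads // n_kv_heads  # 2 query heads per KV head


def kv_head_for(query_head: int) -> int:
    """Return the 0-based KV head that a 0-based query head reads from."""
    return query_head // group_size


# Print the mapping using the 1-based numbering from the text.
for q in range(n_heads):
    print(f"query head {q + 1} -> KV head {kv_head_for(q) + 1}")
```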

Input embeddings (3 × embedding_dim = 3 × 4):

Token 1: [1, 0, 1, 0]
Token 2: [0, 1, 0, 1]
Token 3: [1, 1, 1, 1]

Projection matrices (we’ll define small ones):

Query projections (4 heads, each maps 4-dim embedding → 2-dim):

W_Q1 = [[1, 0],      W_Q2 = [[0, 1],      W_Q3 = [[1, 0],      W_Q4 = [[0, 1],
        [0, 1],              [1, 0],              [1, 1],              [1, 1],
        [0, 0],              [0, 0],              [0, 0],              [0, 0],
        [0, 0]]              [0, 0]]              [0, 0]]              [0, 0]]

KV projections (2 heads, each maps 4-dim → 2-dim):

W_K1 = [[1, 0],      W_K2 = [[1, 1],      W_V1 = [[1, 0],      W_V2 = [[0, 1],
        [0, 1],              [1, 1],              [0, 1],              [1, 0],
        [0, 0],              [0, 0],              [0, 0],              [0, 0],
        [0, 0]]              [0, 0]]              [0, 0]]              [0, 0]]

Step 1: Project embeddings to Q, K, V

Query projections:

Q1(token 1) = [1, 0, 1, 0] @ W_Q1 = [1, 0]
Q1(token 2) = [0, 1, 0, 1] @ W_Q1 = [0, 1]
Q1(token 3) = [1, 1, 1, 1] @ W_Q1 = [1, 1]

Q2(token 1) = [1, 0, 1, 0] @ W_Q2 = [0, 1]
Q2(token 2) = [0, 1, 0, 1] @ W_Q2 = [1, 0]
Q2(token 3) = [1, 1, 1, 1] @ W_Q2 = [1, 1]

Q3(token 1) = [1, 0, 1, 0] @ W_Q3 = [1, 0]
Q3(token 2) = [0, 1, 0, 1] @ W_Q3 = [1, 1]
Q3(token 3) = [1, 1, 1, 1] @ W_Q3 = [2, 1]

Q4(token 1) = [1, 0, 1, 0] @ W_Q4 = [0, 1]
Q4(token 2) = [0, 1, 0, 1] @ W_Q4 = [1, 1]
Q4(token 3) = [1, 1, 1, 1] @ W_Q4 = [1, 2]

KV projections:

K1(token 1) = [1, 0, 1, 0] @ W_K1 = [1, 0]
K1(token 2) = [0, 1, 0, 1] @ W_K1 = [0, 1]
K1(token 3) = [1, 1, 1, 1] @ W_K1 = [1, 1]

K2(token 1) = [1, 0, 1, 0] @ W_K2 = [1, 1]
K2(token 2) = [0, 1, 0, 1] @ W_K2 = [1, 1]
K2(token 3) = [1, 1, 1, 1] @ W_K2 = [2, 2]

V1(token 1) = [1, 0, 1, 0] @ W_V1 = [1, 0]
V1(token 2) = [0, 1, 0, 1] @ W_V1 = [0, 1]
V1(token 3) = [1, 1, 1, 1] @ W_V1 = [1, 1]

V2(token 1) = [1, 0, 1, 0] @ W_V2 = [0, 1]
V2(token 2) = [0, 1, 0, 1] @ W_V2 = [1, 0]
V2(token 3) = [1, 1, 1, 1] @ W_V2 = [1, 1]
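
Every projection in this step can be recomputed from the Setup matrices in a few lines, which makes each row easy to check. This is a sketch using NumPy; the dictionary layout, the helper `proj`, and the names `X`, `W_Q`, `W_K`, `W_V` are choices made for this illustration.

```python
import numpy as np

# Input embeddings: one row per token (3 tokens x 4 dims).
X = np.array([[1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 1, 1, 1]])


def proj(top_rows):
    """Build a 4x2 projection whose last two rows are zero, as in Setup."""
    return np.array(top_rows + [[0, 0], [0, 0]])


W_Q = {1: proj([[1, 0], [0, 1]]), 2: proj([[0, 1], [1, 0]]),
       3: proj([[1, 0], [1, 1]]), 4: proj([[0, 1], [1, 1]])}
W_K = {1: proj([[1, 0], [0, 1]]), 2: proj([[1, 1], [1, 1]])}
W_V = {1: proj([[1, 0], [0, 1]]), 2: proj([[0, 1], [1, 0]])}

# Each product is a 3x2 array: row i holds the projection of token i + 1.
Q = {h: X @ W for h, W in W_Q.items()}
K = {h: X @ W for h, W in W_K.items()}
V = {h: X @ W for h, W in W_V.items()}

for name, mats in (("Q", Q), ("K", K), ("V", V)):
    for h, m in mats.items():
        print(f"{name}{h} =\n{m}")
```

Because the last two rows of every matrix are zero, only the first two embedding dimensions contribute to any projection in this toy example.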

Step 2: Compute attention for each query head at token 3

We compute attention for token 3 only. As the last token in the sequence, it attends to all three positions (tokens 1, 2, and itself), so no causal masking is needed here.
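
For reference, computing all three positions at once requires a causal mask, so that token i sees only tokens 1 through i. The sketch below does this for query head 1 using its Step 1 projections; the function name `causal_attention_weights` is made up for this example.

```python
import numpy as np


def causal_attention_weights(Q, K):
    """Softmax attention weights with a causal (lower-triangular) mask."""
    n, d_head = Q.shape
    scores = Q @ K.T / np.sqrt(d_head)
    # Positions above the diagonal are future tokens: mask them out.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Subtract the row max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


# Head 1 projections from Step 1 (rows are tokens 1, 2, 3).
Q1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
K1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

weights = causal_attention_weights(Q1, K1)
print(np.round(weights, 4))
```

The last row of the result is the token 3 case worked through by hand in this step; the masked entries in the first two rows come out as exact zeros.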

Query Head 1 (attends to KV head 1)

Q1(token 3) = [1, 1]

K1 sequence = [[1, 0], [0, 1], [1, 1]]
V1 sequence = [[1, 0], [0, 1], [1, 1]]

Compute attention scores:

Scores = Q1(3) @ K1^T / √2
       = [1, 1] @ [[1, 0, 1],
                   [0, 1, 1]] / √2
       = [1×1 + 1×0, 1×0 + 1×1, 1×1 + 1×1] / √2
       = [1, 1, 2] / √2
       ≈ [0.707, 0.707, 1.414]

Apply softmax:

exp([0.707, 0.707, 1.414]) ≈ [2.028, 2.028, 4.113]

sum ≈ 8.169

softmax ≈ [2.028/8.169, 2.028/8.169, 4.113/8.169]
        ≈ [0.2483, 0.2483, 0.5035]

Compute output:

Output1(token 3) = [0.2483, 0.2483, 0.5035] @ V1 sequence
                 = [0.2483, 0.2483, 0.5035] @ [[1, 0], [0, 1], [1, 1]]
                 = [0.2483×1 + 0.2483×0 + 0.5035×1,  0.2483×0 + 0.2483×1 + 0.5035×1]
                 = [0.2483 + 0.5035,  0.2483 + 0.5035]
                 = [0.7518, 0.7518]
                 ≈ [0.752, 0.752]
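
The three stages for head 1 (scores, softmax, weighted sum) fit in one small function. A sketch in NumPy, where `attend` is a hypothetical helper name, applied to the head 1 inputs listed above:

```python
import numpy as np


def attend(q, K, V):
    """Scaled dot-product attention for one query vector.

    q: (d_head,) query; K, V: (n_tokens, d_head) keys and values.
    Returns (weights, output).
    """
    scores = K @ q / np.sqrt(q.shape[0])
    exp_scores = np.exp(scores - scores.max())  # numerically stable softmax
    weights = exp_scores / exp_scores.sum()
    return weights, weights @ V


q1 = np.array([1.0, 1.0])                             # Q1(token 3)
K1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # K1 sequence
V1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # V1 sequence

weights, output = attend(q1, K1, V1)
print(np.round(weights, 4), np.round(output, 3))
```

Subtracting the maximum score before exponentiating does not change the softmax result; it only avoids overflow when scores are large. The output rounds to [0.752, 0.752], in line with the hand calculation.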

Query Head 2 (attends to KV head 1 — same as head 1)

Q2(token 3) = [1, 1]  (same as Q1 in this example)

Since Q2 and Q1 are identical and they both attend to the same K1, V1, the output will be the same:

Output2(token 3) = [0.752, 0.752]

(In a real model, Q1 and Q2 would be different, so outputs would differ.)

Query Head 3 (attends to KV head 2)

Q3(token 3) = [1, 1, 1, 1] @ W_Q3 = [2, 1]

K2 sequence = [[1, 1], [1, 1], [2, 2]]   (each token's embedding @ W_K2)
V2 sequence = [[0, 1], [1, 0], [1, 1]]

Compute attention scores:

Scores = Q3(3) @ K2^T / √2
       = [2, 1] @ [[1, 1, 2],
                   [1, 1, 2]] / √2
       = [2×1 + 1×1,  2×1 + 1×1,  2×2 + 1×2] / √2
       = [3, 3, 6] / √2
       ≈ [2.1213, 2.1213, 4.2426]

Apply softmax:

exp([2.1213, 2.1213, 4.2426]) ≈ [8.342, 8.342, 69.59]

sum ≈ 86.27

softmax ≈ [8.342/86.27, 8.342/86.27, 69.59/86.27]
        ≈ [0.0967, 0.0967, 0.8066]

Compute output:

Output3(token 3) = [0.0967, 0.0967, 0.8066] @ V2 sequence
                 = [0.0967, 0.0967, 0.8066] @ [[0, 1], [1, 0], [1, 1]]
                 = [0.0967×0 + 0.0967×1 + 0.8066×1,
                    0.0967×1 + 0.0967×0 + 0.8066×1]
                 = [0.0967 + 0.8066,  0.0967 + 0.8066]
                 = [0.9033, 0.9033]

Query Head 4 (attends to KV head 2, same as head 3)

Q4(token 3) = [1, 1, 1, 1] @ W_Q4 = [1, 2]

Scores = [1, 2] @ K2^T / √2
       = [1×1 + 2×1,  1×1 + 2×1,  1×2 + 2×2] / √2
       = [3, 3, 6] / √2

These are the same scores as for head 3 even though Q4 differs from Q3: every K2 row has two equal components, so only the sum of the query's components affects the dot product. The softmax weights and the output are therefore the same:

Output4(token 3) = [0.9033, 0.9033]

(In a real model, heads 3 and 4 would generally produce different scores and outputs.)

Step 3: Concatenate and project

Concatenate all 4 head outputs:

Concatenated = [Output1, Output2, Output3, Output4]
             ≈ [0.752, 0.752, 0.752, 0.752, 0.903, 0.903, 0.903, 0.903]
             (8-dimensional, 4 heads × 2 dims each)

This is the raw output of the multi-head GQA layer. In a real model, this would be projected back to the original embedding dimension via a final output projection matrix W_O.
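
Putting Steps 1 through 3 together, the sketch below recomputes the concatenated vector for token 3 directly from the Setup matrices and then applies an output projection. The matrix `W_O` here is a made-up placeholder (it averages the eight concatenated values into each of the four embedding dimensions), since this example does not define one. The stacked array layout and the use of `np.repeat` to share KV heads across query heads are also illustration choices, not a description of any particular model's code.

```python
import numpy as np

X = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]], dtype=float)
zeros = [[0, 0], [0, 0]]  # every Setup matrix ends with two zero rows

W_Q = np.array([[[1, 0], [0, 1]] + zeros, [[0, 1], [1, 0]] + zeros,
                [[1, 0], [1, 1]] + zeros, [[0, 1], [1, 1]] + zeros], dtype=float)
W_K = np.array([[[1, 0], [0, 1]] + zeros, [[1, 1], [1, 1]] + zeros], dtype=float)
W_V = np.array([[[1, 0], [0, 1]] + zeros, [[0, 1], [1, 0]] + zeros], dtype=float)

Q = X @ W_Q   # (4 query heads, 3 tokens, 2)
K = X @ W_K   # (2 KV heads, 3 tokens, 2): only these need to be cached
V = X @ W_V   # (2 KV heads, 3 tokens, 2)

# Share each KV head across its group of query heads: [1, 2] -> [1, 1, 2, 2].
group_size = Q.shape[0] // K.shape[0]
K_rep = np.repeat(K, group_size, axis=0)
V_rep = np.repeat(V, group_size, axis=0)

# Attention for token 3 (index 2); it may attend to all three tokens.
q_t3 = Q[:, 2, :]                                                    # (4, 2)
scores = np.einsum("hd,htd->ht", q_t3, K_rep) / np.sqrt(Q.shape[-1])  # (4, 3)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
head_outputs = np.einsum("ht,htd->hd", weights, V_rep)                # (4, 2)

concatenated = head_outputs.reshape(-1)                               # (8,)

# Placeholder output projection (assumption): 8 dims -> 4 dims by averaging.
W_O = np.full((8, 4), 1.0 / 8.0)
projected = concatenated @ W_O                                        # (4,)

print(np.round(concatenated, 4))
print(np.round(projected, 4))
```

In a trained model `W_O` is learned, with shape (n_heads × d_head, embedding_dim); the averaging matrix above only demonstrates the shapes involved.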

Key Insight: Memory Saving

In this example:

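  • Standard MHA: 4 KV heads × 2 tensors (K and V) × d_head 2 = 16 cached values per token
  • GQA: 2 KV heads × 2 tensors (K and V) × d_head 2 = 8 cached values per token
  • Reduction: 2×

The query heads do not enter this count: only keys and values are cached. Scale this to Mistral 7B (32 query heads, 8 KV heads) and the KV cache is 4× smaller than it would be with 32 KV heads, a saving that grows with sequence length and batch size during inference.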
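
To put rough numbers on the saving, the KV cache size can be estimated as 2 (K and V) × tokens × layers × KV heads × d_head × bytes per value. The sketch below applies this to the toy example and to a Mistral 7B-like configuration. The 32 layers, head dimension of 128, and 2-byte (fp16) values are assumptions taken from the publicly released model configuration rather than from this section, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, d_head, bytes_per_value):
    """Bytes needed to cache K and V for n_tokens across all layers."""
    return 2 * n_tokens * n_layers * n_kv_heads * d_head * bytes_per_value


# Toy example: one token, one layer, counting values (bytes_per_value=1).
toy_mha = kv_cache_bytes(1, 1, n_kv_heads=4, d_head=2, bytes_per_value=1)
toy_gqa = kv_cache_bytes(1, 1, n_kv_heads=2, d_head=2, bytes_per_value=1)
print(toy_mha, toy_gqa)  # 16 8

# Mistral 7B-like settings (assumed: 32 layers, d_head 128, fp16 values),
# for a single sequence of 4,096 tokens.
tokens = 4096
mha = kv_cache_bytes(tokens, 32, n_kv_heads=32, d_head=128, bytes_per_value=2)
gqa = kv_cache_bytes(tokens, 32, n_kv_heads=8, d_head=128, bytes_per_value=2)
print(f"MHA: {mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.1f} GiB")
# MHA: 2.0 GiB, GQA: 0.5 GiB
```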

Verification by Hand

The head 3 calculation can also be checked with a short NumPy snippet:

import numpy as np

# Head 3 at token 3; all values follow from the Setup matrices
Q3_t3 = np.array([2, 1])                      # token 3 @ W_Q3
K2_seq = np.array([[1, 1], [1, 1], [2, 2]])   # tokens 1-3 @ W_K2
V2_seq = np.array([[0, 1], [1, 0], [1, 1]])   # tokens 1-3 @ W_V2

# Scaled dot-product scores, with d_head = 2
scores = Q3_t3 @ K2_seq.T / np.sqrt(2)
print(f"Scores: {scores}")  # ≈ [2.121, 2.121, 4.243]

# Softmax over the three key positions
weights = np.exp(scores) / np.sum(np.exp(scores))
print(f"Weights: {weights}")  # ≈ [0.0967, 0.0967, 0.8066]

# Weighted sum of the value vectors
output = weights @ V2_seq
print(f"Output: {output}")  # ≈ [0.9033, 0.9033]

This worked example shows how GQA shrinks the K/V projection parameters and the KV cache while leaving the attention computation itself unchanged. The cost is fewer distinct K/V heads (2× fewer in this toy example, 4× fewer in Mistral 7B); empirically, models trained with GQA have been reported to stay close to the quality of standard MHA.