Section 03

The Idea: GQA and Sliding Window Attention

Mistral 7B 2023

Mistral’s innovation is deceptively simple: two small changes to attention that unlock massive efficiency gains.

Idea 1: Grouped Query Attention (GQA)

The Standard Approach (Multi-Head Attention)

In Multi-Head Attention, every query head has its own key and value head:

Query heads:  Q₁  Q₂  Q₃  Q₄  Q₅  Q₆  ...  Q₃₂
Key heads:    K₁  K₂  K₃  K₄  K₅  K₆  ...  K₃₂
Value heads:  V₁  V₂  V₃  V₄  V₅  V₆  ...  V₃₂

Each Q_i attends to its own K_i and V_i. This is expressive — each head has a personalised set of keys and values. But it requires caching all 32 K and V heads.

Multi-Query Attention (The Problem)

MQA reduces to:

Query heads:  Q₁  Q₂  Q₃  Q₄  Q₅  Q₆  ...  Q₃₂
Key head:     K
Value head:   V

All query heads attend to the same K and V. Now you only cache one K and V head — 32× memory reduction! But the cost is high: all 32 heads are forced to attend to the same information. This reduces model expressiveness and hurts quality.

Grouped Query Attention (The Solution)

GQA groups the query heads:

Query heads:   Q₁ Q₂ Q₃ Q₄    Q₅ Q₆ Q₇ Q₈    ...    Q₂₉ Q₃₀ Q₃₁ Q₃₂
               └────┬────┘    └────┬────┘           └──────┬──────┘
Key heads:          K₁             K₂        ...           K₈
Value heads:        V₁             V₂        ...           V₈
                 Group 1        Group 2      ...        Group 8

In this example: 32 Q heads, 8 KV heads, groups of 4.

  • Query heads 1, 2, 3, 4 all attend to the same K₁, V₁
  • Query heads 5, 6, 7, 8 all attend to the same K₂, V₂
  • And so on…

Why this works: Each group of Q heads can still attend to different features (Q₁, Q₂, Q₃, Q₄ are different), but they share the same key-value representation. This is like having 4 distinct people asking questions, but the answers come from the same reference library. Some expressiveness is lost compared to full MHA, but the quality drop is small — and the memory saving is huge.
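The sharing pattern above can be sketched in a few lines of NumPy. This is a minimal illustration, not Mistral's actual implementation: the function name, the tiny tensor shapes, and the use of np.repeat to expand the KV heads are all assumptions chosen for readability.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """q: (n_heads, seq, d_head); k, v: (n_kv_heads, seq, d_head)."""
    n_heads, seq_len, d_head = q.shape
    n_kv_heads = k.shape[0]
    assert n_heads % n_kv_heads == 0
    group_size = n_heads // n_kv_heads

    # Expand KV heads so query head i reads KV head i // group_size.
    k_rep = np.repeat(k, group_size, axis=0)
    v_rep = np.repeat(v, group_size, axis=0)

    # Scaled dot-product scores, shape (n_heads, seq, seq).
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(d_head)

    # Causal mask: a token may not attend to later tokens.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(causal, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_rep  # (n_heads, seq, d_head)

# Toy sizes (assumed): 32 query heads, 8 KV heads, 6 tokens, d_head = 16.
rng = np.random.default_rng(0)
q = rng.standard_normal((32, 6, 16))
k = rng.standard_normal((8, 6, 16))
v = rng.standard_normal((8, 6, 16))
out = grouped_query_attention(q, k, v)
```

Only k and v (8 heads each) would need to be cached. The repeated copies k_rep and v_rep are views of the same information, which is where the memory saving comes from.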

The math:

  • Standard MHA KV cache (per layer, in stored values): 2 × n_heads × d_head × seq_len
  • GQA KV cache (per layer, in stored values): 2 × n_kv_heads × d_head × seq_len
  • Memory reduction factor: n_heads / n_kv_heads

For Mistral 7B: 32 / 8 = 4× reduction.
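To put the formula in bytes, here is a small worked calculation. The values d_head = 128, 32 layers, an 8,192-token sequence, and 2 bytes per value (fp16) are assumptions added for this example; the function name is hypothetical.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_value=2):
    # 2 tensors (K and V) per layer, each n_kv_heads × d_head × seq_len values.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value

# Assumed configuration: 32 layers, d_head = 128, 8,192 tokens, fp16.
mha_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=8192)
gqa_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=8192)

print(mha_bytes / 2**30, "GiB with 32 KV heads")
print(gqa_bytes / 2**30, "GiB with 8 KV heads")
print(mha_bytes // gqa_bytes, "x reduction")
```

Under these assumptions the cache for one sequence shrinks from 4 GiB to 1 GiB, which matches the n_heads / n_kv_heads factor.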

Idea 2: Sliding Window Attention (SWA)

The Problem with Full Attention

Attention forces each token to look at every previous token:

Token 100 attends to:  [Token 1, Token 2, Token 3, ..., Token 99]

This requires O(n) memory and compute per token.

For an 8,192-token document:

  • Token 8,192 attends to 8,191 other tokens
  • Attention matrix: 8,192 × 8,192 = 67M entries

This is O(n²) — quadratic in sequence length.

The Insight: Information is Local

A key observation: most of the information a token needs comes from nearby tokens, not the entire history.

Think of a novel. When you’re reading page 200, you rarely need to remember specific details from page 5. You remember the gist (which came via summaries and references in intermediate pages), but not exact passages.

In language, the same is true:

  • To understand “the cat jumped over the fence”, you need nearby context (“the cat”, “the fence”), not the entire document history
  • Long-range dependencies do occur (“John left London… he never came back to the city”), but they’re rare and already embedded in intermediate summaries

Sliding Window Attention (The Solution)

SWA formalises this: each token attends only to the last W tokens (a sliding window), not all previous tokens:

Window size W = 4096

Token 100 attends to:  [Token 1-100]     (fewer than 4096 tokens exist so far)
Token 5000 attends to: [Token 905-5000]  (exactly 4096 tokens)
Token 8192 attends to: [Token 4097-8192] (exactly 4096 tokens)
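A sliding-window mask can be sketched as follows. This assumes the convention used in this section, where the window of W tokens includes the current token; the helper names are hypothetical.

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    """mask[i, j] is True when query position i may attend to key position j."""
    i = np.arange(seq_len)[:, None]  # query positions (0-indexed)
    j = np.arange(seq_len)[None, :]  # key positions (0-indexed)
    # Causal (j <= i) and within the last `window` positions, counting i itself.
    return (j <= i) & (j > i - window)

def attended_range(token, window):
    """1-indexed (first, last) tokens that `token` attends to."""
    return max(1, token - window + 1), token

mask = sliding_window_mask(seq_len=8, window=3)
print(mask.astype(int))
print(attended_range(100, 4096))
print(attended_range(5000, 4096))
print(attended_range(8192, 4096))
```

In the 8 × 3 toy mask, the first rows have fewer than three True entries because the window has not filled yet; every later row has exactly three.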

Why long-range dependencies still work:

With deep networks (e.g., 32 layers), information propagates:

Layer 1:  Token 8192 attends to tokens [4097-8192]
Layer 2:  Token 8192 can now "see" tokens [1-8192]
          (because tokens [4097-8192] already attend to earlier tokens)

This reach is called the receptive field. After k layers, each token can be indirectly influenced by up to about k × W previous tokens:

Receptive field ≈ k × W
Mistral 7B: 32 layers × 4,096 window = 131,072 tokens of theoretical attention span

So even with W = 4,096, long-range information is not cut off outright. It takes more layers to propagate, and k × W is an upper bound on reach, not a guarantee that distant details survive the trip.
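The layer-by-layer growth can be checked on a toy example by composing the window mask with itself. This is an illustrative sketch with assumed small sizes and a hypothetical function name. Because the window here counts the current token, each layer reaches W − 1 positions further back, so the exact count is k × (W − 1) + 1, which k × W approximates closely when W is large.

```python
import numpy as np

def receptive_field(seq_len, window, n_layers):
    """Number of positions that can influence the last token after n_layers."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = ((j <= i) & (j > i - window)).astype(np.int64)

    # reach[i, j] > 0 means information from position j can reach position i.
    reach = np.eye(seq_len, dtype=np.int64)
    for _ in range(n_layers):
        # One more layer: i attends to m, and m already carries information from j.
        reach = ((mask @ reach) > 0).astype(np.int64)
    return int(reach[-1].sum())

for k in (1, 2, 3, 10):
    print(k, "layers ->", receptive_field(seq_len=20, window=4, n_layers=k))
```

With 20 tokens and W = 4, the reach grows by three positions per layer until it covers the whole sequence.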

Complexity Reduction

  • Standard attention per layer: O(n²) — all n tokens attend to all n tokens
  • SWA per layer: O(n × W) — each of n tokens attends to W tokens
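  • For n = 8,192, W = 4,096: reduction from 67M to 33M operations per layer (~2×)
  • Across all 32 layers: still ~2× in total attention compute, because every layer saves the same factor; the ratio n / W grows with sequence length (e.g., ~8× at n = 32,768)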
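The arithmetic can be reproduced with a short calculation. It uses this section's simplified count (n × n versus n × W query-key pairs, ignoring causal masking and constant factors), and the function name is hypothetical.

```python
def attention_ops(n, window=None, n_layers=1):
    """Query-key pairs scored, using the simplified count n × n vs n × W."""
    keys_per_query = n if window is None else min(n, window)
    return n * keys_per_query * n_layers

n, W, layers = 8192, 4096, 32

per_layer_ratio = attention_ops(n) / attention_ops(n, W)
total_ratio = attention_ops(n, n_layers=layers) / attention_ops(n, W, n_layers=layers)
long_ratio = attention_ops(32768) / attention_ops(32768, W)

print(attention_ops(n), "vs", attention_ops(n, W), "pairs per layer")
print("per-layer ratio:", per_layer_ratio)
print("ratio over 32 layers:", total_ratio)
print("ratio at n = 32,768:", long_ratio)
```

The ratio is the same for one layer and for all 32, since the number of layers multiplies both sides equally. What changes the ratio is the sequence length relative to W.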

Combining Both Ideas

The brilliance of Mistral is combining both:

  1. GQA: 4× memory reduction in KV cache
  2. SWA: roughly n / W compute reduction once sequences exceed W (2× at 8K tokens), plus a KV cache limited to W tokens instead of seq_len

Together:

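  • KV cache memory: 4× smaller per token (GQA) and capped at 4K tokens instead of 8K (SWA), so about 8× smaller for an 8K-token sequence
  • Attention compute: about 2× less at 8K tokens, from SWA; GQA leaves the number of query-key products unchanged and mainly cuts memory traffic
  • Inference throughput: higher in practice, because a smaller cache leaves room for larger batches; the exact gain depends on hardware, batch size, and sequence length
  • Quality: close to standard attention; both changes trade a little expressiveness for efficiency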
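One way to keep the cache at W tokens is a rolling buffer, where the entry for position t overwrites slot t mod W. The class below is a simplified sketch of that idea for a single layer; the class name, method names, and toy sizes are assumptions, not Mistral's code.

```python
import numpy as np

class RollingKVCache:
    """Fixed-size KV cache for one layer: position t is stored in slot t % window."""

    def __init__(self, window, n_kv_heads, d_head):
        self.window = window
        self.k = np.zeros((window, n_kv_heads, d_head))
        self.v = np.zeros((window, n_kv_heads, d_head))
        self.n_seen = 0  # total tokens appended so far

    def append(self, k_t, v_t):
        slot = self.n_seen % self.window  # overwrites the oldest entry once full
        self.k[slot] = k_t
        self.v[slot] = v_t
        self.n_seen += 1

    def get(self):
        """Return cached K and V in chronological order (oldest first)."""
        if self.n_seen <= self.window:
            return self.k[:self.n_seen], self.v[:self.n_seen]
        start = self.n_seen % self.window  # slot holding the oldest entry
        return np.roll(self.k, -start, axis=0), np.roll(self.v, -start, axis=0)

# Toy run (assumed sizes): window of 4, 10 tokens, each entry filled with its position.
cache = RollingKVCache(window=4, n_kv_heads=2, d_head=3)
for t in range(10):
    cache.append(np.full((2, 3), t), np.full((2, 3), t))

k_cached, v_cached = cache.get()
positions = k_cached[:, 0, 0].tolist()
print(positions)
```

However many tokens are appended, the buffer holds only the most recent `window` entries, so memory stays constant after the window fills.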

The Indian Analogy Recap

GQA: Imagine 32 university students (query heads) writing question papers. Instead of each student writing on their own question bank (MHA), they’re grouped into 8 groups of 4. Each group of 4 students uses the same reference library (K and V). The group can still write 4 different papers (Q₁, Q₂, Q₃, Q₄ are different), but they draw from common sources. Cost is reduced, quality loss is minimal.

SWA: Imagine you’re reading a long legal document, but your RAM only holds the last 10 pages. You don’t memorise everything from the start. Your understanding of page 100 comes from careful reading of pages 91–100, plus the fact that intermediate pages referenced earlier sections. By the time you reach page 100, you have an implicit grasp of pages 1–10 (via cascading references). You never read all 100 pages simultaneously; you only ever focus on the last 10.

This is exactly how Mistral 7B works.