Section 02

The Problem: Attention is Memory-Hungry

Mistral 7B 2023

Standard Transformer attention is computationally elegant but memory-inefficient. Three specific problems limit its practical use as of 2023:

Problem 1: The KV Cache Blowup

Multi-Head Attention (MHA) requires:

  • n_heads independent K and V caches
  • Each cache stores d_head floats per token per layer

For a single decoder layer with 32 heads of dimension 128, generating 8,192 tokens:

KV cache = 2 (K and V) × 32 heads × 128 dims × 8,192 tokens
         = 67,108,864 floats
         = 256 MB per sample (at 4 bytes per float; float16 would halve this)

Scale to 32 layers (the depth of models like LLaMA 2 7B), and one inference sample requires:

Total KV cache = 32 layers × 256 MB = 8,192 MB ≈ 8 GB per sample

Run a batch of 32 samples (a typical inference server batch), and you need 256 GB of GPU memory just for KV caches — before loading model weights.
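The arithmetic above can be checked with a small helper. This is a minimal sketch, not code from the paper: the function name `kv_cache_bytes` and its parameters are hypothetical, and it assumes 4 bytes per cached value (float32), which is what the 256 MB figure implies. Sizes are reported in binary units (MiB, GiB).

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_tokens,
                   bytes_per_value=4, batch_size=1):
    """Bytes needed to cache keys and values for `batch_size` sequences."""
    # Factor 2: one tensor for keys and one for values.
    n_values = 2 * n_layers * n_kv_heads * d_head * n_tokens * batch_size
    return n_values * bytes_per_value

MIB = 2 ** 20
GIB = 2 ** 30

one_layer = kv_cache_bytes(1, 32, 128, 8192)
full_model = kv_cache_bytes(32, 32, 128, 8192)
batch_32 = kv_cache_bytes(32, 32, 128, 8192, batch_size=32)

print(one_layer / MIB)   # 256.0
print(full_model / GIB)  # 8.0
print(batch_32 / GIB)    # 256.0
```

Passing `bytes_per_value=2` models a float16 cache and halves every figure; the growth with layers, tokens, and batch size stays the same.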

Why this is a real problem:

  • A single GPU (H100 with 80 GB) can hold a 7B-parameter model's weights (~14 GB in float16) and KV caches of roughly 8 GB each for only about 8 samples at full context
  • Longer sequences → linearly larger KV cache → proportionally fewer samples per GPU
  • Mobile/edge inference becomes impractical

Problem 2: Quadratic Attention Complexity

Standard attention computes Q @ K^T where both Q and K have length n (sequence length).

This is O(n²) — the number of pairs to attend to grows quadratically.

For a 32-token sequence: 32 × 32 = 1,024 pairs
For an 8,192-token sequence: 8,192 × 8,192 = 67,108,864 pairs

For longer documents this becomes prohibitive in time as well as memory. Each of the 32 layers performs O(n²) operations, so doubling the sequence length quadruples the attention cost in every layer.
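To make the quadratic growth concrete, here is a minimal pure-Python sketch of the score computation. The helper names `attention_scores` and `attention_pairs` are illustrative only, and the 1/sqrt(d) scaling is the usual scaled dot-product convention rather than something stated in the text above.

```python
import math

def attention_scores(queries, keys):
    """Return the n x n matrix of scaled dot products, Q @ K^T / sqrt(d)."""
    d = len(keys[0])
    scale = 1.0 / math.sqrt(d)
    # One score per (query, key) pair: n queries x n keys entries.
    return [[scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in keys]
            for q in queries]

def attention_pairs(n):
    """Number of (query, key) pairs for a length-n sequence, no masking."""
    return n * n

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = attention_scores(tokens, tokens)
print(len(scores), len(scores[0]))  # 3 3
print(attention_pairs(32))          # 1024
print(attention_pairs(8192))        # 67108864
```

A causal mask roughly halves the number of pairs that are actually used, but the growth remains quadratic in n.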

Problem 3: Naive Solutions Don’t Work

Why can’t we just use smaller models? LLaMA 2 7B exists, but it underperforms 13B on reasoning tasks. The gap is real:

  • LLaMA 2 7B: 43.35% on GSM8k (grade-school math)
  • LLaMA 2 13B: 56.44% on GSM8k (about 13 points higher, roughly 30% better in relative terms)

Why can’t we use Multi-Query Attention (MQA)? MQA solves the KV cache problem by sharing a single KV head across all query heads:

Standard MHA: n_heads KV heads (e.g., 32)
MQA: 1 KV head
KV cache reduction: 32×

But MQA costs quality. Forcing all query heads to attend to the same key-value pairs reduces expressiveness, and an MQA 7B model tends to underperform a standard MHA 7B model.
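The sharing that MQA introduces can be sketched for a single decoding step. This is an illustrative pure-Python version under simplifying assumptions (one query position, no batching, no projections); `softmax` and `mqa_single_step` are hypothetical helper names, not part of any library.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def mqa_single_step(q_heads, keys, values):
    """Attention output for each query head at one position.

    q_heads: one query vector per head.
    keys, values: a single shared list of cached vectors, one per token.
    """
    d = len(keys[0])
    outputs = []
    for q in q_heads:
        # Every head scores against the SAME cached keys.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Every head also mixes the SAME cached values.
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
    return outputs

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[2.0, 0.0], [0.0, 4.0]]
outs = mqa_single_step([[0.0, 0.0], [1.0, 0.0]], keys, values)
print(outs[0])  # [1.0, 2.0]
```

The cache holds one K list and one V list regardless of how many query heads there are, which is where the 32× reduction comes from. The heads can still differ in their queries, but not in what they look up.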

Why can’t we just limit context to short sequences? Many tasks require reasoning over longer contexts:

  • Summarising documents (4K–32K tokens)
  • Code understanding (full files can be thousands of tokens)
  • Few-shot learning (context examples + prompt > 1K tokens)

A 512-token limit feels like writing novels on paper napkins.

The insight Mistral had

The solution isn’t to eliminate KV heads or ignore attention — it’s to compromise intelligently:

  1. Grouped Query Attention (GQA): Share KV heads, but not completely. Use n_kv_heads < n_heads where n_heads / n_kv_heads = 4 or 8. This preserves most of the expressiveness of MHA while achieving 4–8× KV cache reduction.

  2. Sliding Window Attention (SWA): Most of the information a token needs comes from nearby tokens, not the entire history. Limit each token to attend only to the last W tokens (e.g., W = 4,096). This reduces attention from O(n²) to O(n × W).
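Both ideas in the list above can be sketched in a few lines of pure Python. The function names are hypothetical, and one convention is assumed here that the text does not pin down: the window of W tokens includes the current token.

```python
def kv_head_index(q_head, n_heads, n_kv_heads):
    """GQA: map a query head to the KV head shared by its group."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group_size = n_heads // n_kv_heads
    return q_head // group_size

def sliding_window_mask(n, window):
    """mask[i][j] is True when token i may attend to token j.

    Causal (j <= i) and limited to the last `window` tokens,
    counting token i itself (an assumption of this sketch).
    """
    return [[(j <= i) and (j > i - window) for j in range(n)]
            for i in range(n)]

def attended_pairs(n, window):
    """Total (query, key) pairs under the sliding-window causal mask."""
    return sum(min(i + 1, window) for i in range(n))

print(kv_head_index(5, 32, 8))       # 1 (heads 4..7 share KV head 1)
print(sliding_window_mask(5, 3)[4])  # [False, False, True, True, True]
print(attended_pairs(5, 3))          # 12
```

With `n_kv_heads=1` the mapping collapses to MQA, and with `n_kv_heads == n_heads` it is plain MHA. Once n exceeds the window, `attended_pairs` grows by `window` per extra token rather than by the full sequence length.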

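The two ideas are complementary. GQA shrinks the KV cache by 4–8×, and SWA caps the number of tokens each query attends to at W, so attention cost grows linearly rather than quadratically once n exceeds W. Both are intended to do this with little or no loss in quality.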
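A rough sizing comparison shows how the savings stack. This sketch counts cached values rather than bytes and uses the 32-layer, 32-head, 128-dimension configuration from earlier in the section. It also makes an assumption the text above does not state: under SWA the cache keeps only the last W tokens per layer. The name `kv_cache_values` is hypothetical.

```python
def kv_cache_values(n_layers, n_kv_heads, d_head, cached_tokens):
    """Number of cached K and V values for one sequence."""
    return 2 * n_layers * n_kv_heads * d_head * cached_tokens

N_TOKENS = 8192
WINDOW = 4096

mha = kv_cache_values(32, 32, 128, N_TOKENS)
gqa = kv_cache_values(32, 8, 128, N_TOKENS)
# Assumption: tokens older than the window are dropped from the cache.
gqa_swa = kv_cache_values(32, 8, 128, min(N_TOKENS, WINDOW))

print(mha // gqa)      # 4
print(mha // gqa_swa)  # 8
```

The GQA factor is fixed by the head ratio, while the window factor depends on how far the sequence length exceeds W.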

The rest of the paper shows how.