Section 05

Worked Example: Ring Attention Step-by-Step

Ring Attention with Blockwise Transformers for Near-Infinite Context 2023

This section walks through a concrete example of Ring Attention with 4 GPUs and 8 tokens. It shows how KV chunks circulate around the ring and how each GPU accumulates its attention output.

Setup

Parameters:

  • Total sequence: 8 tokens
  • Number of GPUs: 4 (each GPU holds 2 tokens)
  • Embedding dimension: 2
  • 1 attention head (for simplicity)

Initial Distribution:

GPU 0: tokens [0, 1]
GPU 1: tokens [2, 3]
GPU 2: tokens [4, 5]
GPU 3: tokens [6, 7]

Embeddings:

Token 0: [1, 0]
Token 1: [0, 1]
Token 2: [1, 1]
Token 3: [1, 2]
Token 4: [2, 1]
Token 5: [2, 2]
Token 6: [3, 1]
Token 7: [3, 3]

Projection Matrices (identity for simplicity):

  • W_Q = W_K = W_V = I (identity), so each token's embedding is also its query, key and value

GPU Initial State:

GPU 0: Q[0:2] = [[1,0], [0,1]]    KV[0:2] = (same)
GPU 1: Q[2:4] = [[1,1], [1,2]]    KV[2:4] = (same)
GPU 2: Q[4:6] = [[2,1], [2,2]]    KV[4:6] = (same)
GPU 3: Q[6:8] = [[3,1], [3,3]]    KV[6:8] = (same)
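The initial distribution amounts to cutting the sequence into P equal contiguous chunks, one per GPU. A minimal sketch with plain Python lists; `shard`, `EMBEDDINGS` and `NUM_GPUS` are illustrative names, not any framework's API:

```python
# Toy setup from the example: 8 tokens with 2-dimensional embeddings, 4 GPUs.
# With identity projections each embedding doubles as its query, key and value.
EMBEDDINGS = [[1, 0], [0, 1], [1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [3, 3]]
NUM_GPUS = 4


def shard(tokens, num_devices):
    """Split a sequence into equal contiguous chunks, one per device."""
    if len(tokens) % num_devices != 0:
        raise ValueError("sequence length must be divisible by the number of devices")
    size = len(tokens) // num_devices
    return [tokens[g * size:(g + 1) * size] for g in range(num_devices)]


for gpu, block in enumerate(shard(EMBEDDINGS, NUM_GPUS)):
    first = gpu * len(block)
    # First line printed: GPU 0: tokens [0, 1] -> [[1, 0], [0, 1]]
    print(f"GPU {gpu}: tokens [{first}, {first + len(block) - 1}] -> {block}")
```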

Round 0: Each GPU Computes with Its Own KV

GPU 0: Q[0:2] @ KV[0:2]

Q[0:2] = [[1, 0], [0, 1]]
K[0:2] = [[1, 0], [0, 1]]
V[0:2] = [[1, 0], [0, 1]]

Attention scores = Q @ K^T / √2:
  Q[0] @ K^T = [1, 0] @ [[1, 0], [0, 1]]^T = [1, 0]
  Q[1] @ K^T = [0, 1] @ [[1, 0], [0, 1]]^T = [0, 1]

Scores = [[1, 0], [0, 1]] / √2 ≈ [[0.707, 0], [0, 0.707]]

Softmax (per row):
  Row 0: softmax([0.707, 0]) ≈ [0.670, 0.330]
  Row 1: softmax([0, 0.707]) ≈ [0.330, 0.670]

Weights = [[0.670, 0.330], [0.330, 0.670]]

Output_0_round_0 = Weights @ V[0:2]
                 = [[0.670, 0.330], [0.330, 0.670]] @ [[1, 0], [0, 1]]
                 = [[0.670, 0.330], [0.330, 0.670]]

Communication: GPU 0 sends KV[0:2] to GPU 1, receives KV[6:8] from GPU 3.
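Each GPU's per-round work is ordinary scaled dot-product attention between one query block and one KV block. A minimal pure-Python sketch, assuming identity projections as above; `block_attention` is a hypothetical helper name. It also returns each row's sum of exponentials, which is the statistic later rounds need in order to merge results correctly:

```python
import math


def block_attention(q_block, k_block, v_block):
    """Scaled dot-product softmax attention of one query block against one KV block.

    Returns (outputs, sums): the softmax-normalised output for each query row,
    and each row's sum of exp(score), which later rounds need for merging.
    Uses plain exponentials (no max subtraction), which is fine for these tiny scores.
    """
    d = len(q_block[0])
    outputs, sums = [], []
    for q in q_block:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_block]
        exps = [math.exp(s) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output row = softmax weights applied to the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, v_block))
                        for j in range(len(v_block[0]))])
        sums.append(total)
    return outputs, sums


# GPU 0, round 0: Q[0:2] against its own KV[0:2] (Q = K = V with identity projections).
q0 = [[1, 0], [0, 1]]
out, _ = block_attention(q0, q0, q0)
print([[round(x, 3) for x in row] for row in out])  # [[0.67, 0.33], [0.33, 0.67]]
```

Calling the same helper with GPU 1's block, `block_attention(q1, q1, q1)` where `q1 = [[1, 1], [1, 2]]`, reproduces that GPU's round-0 output.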

GPU 1: Q[2:4] @ KV[2:4]

Q[2:4] = [[1, 1], [1, 2]]
K[2:4] = [[1, 1], [1, 2]]

Scores = Q @ K^T / √2:
  [1, 1] @ [[1, 1], [1, 2]]^T = [2, 3] / √2 ≈ [1.414, 2.121]
  [1, 2] @ [[1, 1], [1, 2]]^T = [3, 5] / √2 ≈ [2.121, 3.536]

Softmax:
  Row 0: softmax([1.414, 2.121]) ≈ [0.330, 0.670]
  Row 1: softmax([2.121, 3.536]) ≈ [0.196, 0.804]

Output_1_round_0 = [[0.330, 0.670], [0.196, 0.804]] @ V[2:4]
                 = [[0.330, 0.670], [0.196, 0.804]] @ [[1, 1], [1, 2]]
                 = [[1.000, 1.670], [1.000, 1.804]]

Communication: GPU 1 sends KV[2:4] to GPU 2, receives KV[0:2] from GPU 0.

GPU 2, GPU 3: Similar

(Skipping detailed calculations for brevity, but same pattern.)

After Round 0:

  • Each GPU has computed attention with its own KV chunk
  • KV chunks are circulating: GPU i sends to GPU (i+1) % 4
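The circulation rule implies a closed form for which chunk each GPU holds: after r hops, GPU i holds the chunk that started on GPU (i - r) mod P. A small sketch that prints the full four-round schedule; the helper names `kv_owner` and `chunk_slice` are illustrative, not from any library:

```python
def kv_owner(gpu, round_idx, num_gpus):
    """Index of the GPU whose KV chunk `gpu` holds during `round_idx`.

    Chunks move from GPU i to GPU (i + 1) % P after every round, so after
    r hops GPU i holds the chunk that started on GPU (i - r) % P.
    """
    return (gpu - round_idx) % num_gpus


def chunk_slice(owner, tokens_per_gpu):
    """Format the token range of a chunk, e.g. 'KV[6:8]'."""
    start = owner * tokens_per_gpu
    return f"KV[{start}:{start + tokens_per_gpu}]"


P, TOKENS_PER_GPU = 4, 2
for gpu in range(P):
    row = [chunk_slice(kv_owner(gpu, r, P), TOKENS_PER_GPU) for r in range(P)]
    # First line printed: GPU 0: KV[0:2], KV[6:8], KV[4:6], KV[2:4]
    print(f"GPU {gpu}: " + ", ".join(row))
```

Because (i - r) mod P takes every value from 0 to P - 1 as r runs over P rounds, each GPU sees every chunk exactly once, so P rounds suffice.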

Round 1: Each GPU Computes with Neighbour’s KV

Now GPU 0 holds KV[6:8] (received from GPU 3), so it computes Q[0:2] @ KV[6:8].

GPU 0: Q[0:2] @ KV[6:8]

Q[0:2] = [[1, 0], [0, 1]]
K[6:8] = [[3, 1], [3, 3]]

Scores = Q @ K^T / √2:
  [1, 0] @ [[3, 1], [3, 3]]^T = [3, 3] / √2 ≈ [2.121, 2.121]
  [0, 1] @ [[3, 1], [3, 3]]^T = [1, 3] / √2 ≈ [0.707, 2.121]

Softmax:
  Row 0: softmax([2.121, 2.121]) ≈ [0.5, 0.5]
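  Row 1: softmax([0.707, 2.121]) ≈ [0.196, 0.804]

Output_0_round_1 = [[0.5, 0.5], [0.196, 0.804]] @ V[6:8]
                 = [[0.5, 0.5], [0.196, 0.804]] @ [[3, 1], [3, 3]]
                 = [[3.000, 2.000], [3.000, 2.609]]

Accumulate: the two partial outputs cannot simply be added, because each was normalised by its own block's softmax denominator. Instead, keep ℓ = Σ exp(score) for each query row and block, and take the weighted average (the online softmax merge):

Output_0 = (ℓ_r0 · Output_0_round_0 + ℓ_r1 · Output_0_round_1) / (ℓ_r0 + ℓ_r1)

  Row 0: ℓ_r0 = e^0.707 + e^0 ≈ 3.028,  ℓ_r1 = e^2.121 + e^2.121 ≈ 16.684
         (3.028 · [0.670, 0.330] + 16.684 · [3, 2]) / 19.712 ≈ [2.64, 1.74]
  Row 1: ℓ_r0 = e^0 + e^0.707 ≈ 3.028,  ℓ_r1 = e^0.707 + e^2.121 ≈ 10.370
         (3.028 · [0.330, 0.670] + 10.370 · [3, 2.609]) / 13.398 ≈ [2.40, 2.17]

Output_0 (after 2 of 4 rounds) ≈ [[2.64, 1.74], [2.40, 2.17]]

Real implementations also track a running maximum score m per row and store exp(score - m) instead of exp(score), rescaling earlier sums whenever m grows, so the exponentials cannot overflow. With scores this small, plain exponentials are fine.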
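Combining two rounds requires weighting each partial output by its block's sum of exponentials rather than adding the outputs. A minimal sketch that merges GPU 0's first two rounds and checks the result against attention over all four keys at once. `attend` and `merge_partials` are illustrative names, and the plain exponentials (no max subtraction) are only safe for small scores like these:

```python
import math


def attend(q_rows, kv_rows):
    """Softmax attention with K = V = kv_rows (identity projections, as in the example).

    Returns each query row's normalised output and its sum of exp(score).
    """
    d = len(q_rows[0])
    outs, sums = [], []
    for q in q_rows:
        exps = [math.exp(sum(a * b for a, b in zip(q, k)) / math.sqrt(d)) for k in kv_rows]
        total = sum(exps)
        outs.append([sum(e * v[j] for e, v in zip(exps, kv_rows)) / total for j in range(d)])
        sums.append(total)
    return outs, sums


def merge_partials(out_a, sum_a, out_b, sum_b):
    """Combine two normalised partial outputs for the same query rows.

    Each partial was divided by its own block's sum of exponentials, so the exact
    combination is the weighted average (l_a * O_a + l_b * O_b) / (l_a + l_b).
    """
    merged_out, merged_sum = [], []
    for oa, la, ob, lb in zip(out_a, sum_a, out_b, sum_b):
        total = la + lb
        merged_out.append([(la * xa + lb * xb) / total for xa, xb in zip(oa, ob)])
        merged_sum.append(total)
    return merged_out, merged_sum


q = [[1, 0], [0, 1]]                                      # GPU 0's queries
o0, l0 = attend(q, [[1, 0], [0, 1]])                      # round 0: own KV[0:2]
o1, l1 = attend(q, [[3, 1], [3, 3]])                      # round 1: KV[6:8]
merged, _ = merge_partials(o0, l0, o1, l1)
direct, _ = attend(q, [[1, 0], [0, 1], [3, 1], [3, 3]])   # all four keys at once
print([[round(x, 2) for x in row] for row in merged])     # [[2.64, 1.74], [2.4, 2.17]]
```

Keeping the sums alongside the outputs is what lets the merge happen one round at a time, without revisiting earlier KV chunks.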

GPU 1: Q[2:4] @ KV[0:2]

Q[2:4] = [[1, 1], [1, 2]]
K[0:2] = [[1, 0], [0, 1]]

Scores = Q @ K^T / √2:
  [1, 1] @ [[1, 0], [0, 1]]^T = [1, 1] / √2 ≈ [0.707, 0.707]
  [1, 2] @ [[1, 0], [0, 1]]^T = [1, 2] / √2 ≈ [0.707, 1.414]

Softmax:
  Row 0: softmax([0.707, 0.707]) ≈ [0.5, 0.5]
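  Row 1: softmax([0.707, 1.414]) ≈ [0.330, 0.670]

Output_1_round_1 = [[0.5, 0.5], [0.330, 0.670]] @ V[0:2]
                 = [[0.5, 0.5], [0.330, 0.670]]

Accumulate with the online softmax merge, Output_1 = (ℓ_r0 · Output_1_round_0 + ℓ_r1 · Output_1_round_1) / (ℓ_r0 + ℓ_r1), where ℓ is each row's sum of exp(score) within a block:

  Row 0: ℓ_r0 = e^1.414 + e^2.121 ≈ 12.455,  ℓ_r1 = e^0.707 + e^0.707 ≈ 4.056
         (12.455 · [1.000, 1.670] + 4.056 · [0.5, 0.5]) / 16.511 ≈ [0.88, 1.38]
  Row 1: ℓ_r0 = e^2.121 + e^3.536 ≈ 42.655,  ℓ_r1 = e^0.707 + e^1.414 ≈ 6.141
         (42.655 · [1.000, 1.804] + 6.141 · [0.330, 0.670]) / 48.796 ≈ [0.92, 1.66]

Output_1 (after 2 of 4 rounds) ≈ [[0.88, 1.38], [0.92, 1.66]]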

Round 2 & 3: Continue Circulation

KV chunks continue circulating. After Round 1's communication step:

  • GPU 0 has received KV[4:6] (from GPU 3, originally from GPU 2)
  • GPU 1 has received KV[6:8] (from GPU 0, originally from GPU 3)
  • And so on…

In Rounds 2 and 3, each GPU computes with the remaining KV chunks it hasn’t yet processed.

After all 4 rounds, every GPU has computed:

GPU 0: Q[0:2] @ KV[0:2], Q[0:2] @ KV[6:8], Q[0:2] @ KV[4:6], Q[0:2] @ KV[2:4]
GPU 1: Q[2:4] @ KV[2:4], Q[2:4] @ KV[0:2], Q[2:4] @ KV[6:8], Q[2:4] @ KV[4:6]
GPU 2: Q[4:6] @ KV[4:6], Q[4:6] @ KV[2:4], Q[4:6] @ KV[0:2], Q[4:6] @ KV[6:8]
GPU 3: Q[6:8] @ KV[6:8], Q[6:8] @ KV[4:6], Q[6:8] @ KV[2:4], Q[6:8] @ KV[0:2]

Once the per-round results are merged with the online softmax, this is exactly equivalent to:

Full Attention = softmax(Q @ K^T / √d) @ V
where Q, K and V span all 8 tokens (full sequence).
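Why merging gives exact full attention follows from splitting the softmax sums over the P KV chunks. For one query row q, write B_j for the keys in chunk j, v_k for the value paired with key k, ℓ_j for chunk j's sum of exponentials and O_j for the softmax output over chunk j alone (notation introduced here for the derivation):

```latex
\begin{aligned}
s_k &= \frac{q \cdot k}{\sqrt{d}}, \qquad
\ell_j = \sum_{k \in B_j} e^{s_k}, \qquad
O_j = \frac{1}{\ell_j} \sum_{k \in B_j} e^{s_k}\, v_k \\[4pt]
\mathrm{Attn}(q)
  &= \frac{\sum_{j=1}^{P} \sum_{k \in B_j} e^{s_k}\, v_k}{\sum_{j=1}^{P} \sum_{k \in B_j} e^{s_k}}
   = \frac{\sum_{j=1}^{P} \ell_j\, O_j}{\sum_{j=1}^{P} \ell_j}
\end{aligned}
```

Subtracting a running maximum from every s_k, as practical implementations do, multiplies numerator and denominator by the same factor and so leaves the result unchanged.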

Key Observations

  1. No replication: Each GPU holds only one KV chunk at a time (plus a buffer for the chunk it is receiving), and no GPU ever stores K and V for the whole sequence. KV chunks circulate rather than being duplicated.

  2. Balanced communication: Each GPU sends and receives the same amount of data each round.

  3. Full attention achieved: Even though Q and KV are split across GPUs, the merged output equals full attention on a single GPU up to floating-point rounding. Ring Attention is exact, not an approximation.

  4. Deterministic ordering: GPU i always receives KV from GPU (i-1) mod P, so the communication schedule is fixed in advance. There is no random access.

  5. Scalability: This pattern works for any number of GPUs P and any sequence length n that is divisible by P.

Numerical Verification

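Ring Attention's final merged output matches full attention computed on all 8 tokens at once on a single GPU, up to floating-point rounding. What makes this exact is the online softmax trick: each GPU keeps per-row softmax statistics (a running maximum score and a running sum of exponentials) and rescales its partial output as each new KV chunk arrives. Simply adding the per-block softmax outputs would not give the right answer.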
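The equivalence can be checked with a small single-process simulation. This is a minimal pure-Python illustration, not the paper's implementation: `ring_attention` and `full_attention` are hypothetical helper names, the four GPUs are simulated sequentially in one loop instead of communicating, and, like the worked example, it uses bidirectional attention with no causal mask. Each simulated GPU uses the numerically stable form of the online softmax (running max, running sum of exponentials, unnormalised accumulator):

```python
import math

# Toy data from the worked example; identity projections make Q = K = V = embeddings.
TOKENS = [[1, 0], [0, 1], [1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [3, 3]]
P = 4                     # simulated GPUs
C = len(TOKENS) // P      # tokens per GPU
D = len(TOKENS[0])        # embedding dimension


def score(q, k):
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(D)


def full_attention(q_rows, kv_rows):
    """Reference: softmax(Q K^T / sqrt(d)) V over the whole sequence at once."""
    out = []
    for q in q_rows:
        s = [score(q, k) for k in kv_rows]
        m = max(s)
        exps = [math.exp(x - m) for x in s]
        total = sum(exps)
        out.append([sum(e * v[j] for e, v in zip(exps, kv_rows)) / total for j in range(D)])
    return out


def ring_attention(tokens):
    """Simulate the ring: in round r, GPU g holds KV chunk (g - r) % P and folds it
    into a per-row running state (max, sum of exps, unnormalised output)."""
    chunks = [tokens[g * C:(g + 1) * C] for g in range(P)]
    result = []
    for g in range(P):
        run_max = [-math.inf] * C
        run_sum = [0.0] * C
        acc = [[0.0] * D for _ in range(C)]
        for r in range(P):
            kv = chunks[(g - r) % P]
            for i, q in enumerate(chunks[g]):
                s = [score(q, k) for k in kv]
                new_max = max(run_max[i], max(s))
                scale = math.exp(run_max[i] - new_max)  # 0.0 in the first round
                exps = [math.exp(x - new_max) for x in s]
                run_sum[i] = run_sum[i] * scale + sum(exps)
                for j in range(D):
                    acc[i][j] = acc[i][j] * scale + sum(e * v[j] for e, v in zip(exps, kv))
                run_max[i] = new_max
        result.extend([[x / run_sum[i] for x in acc[i]] for i in range(C)])
    return result


ring = ring_attention(TOKENS)
ref = full_attention(TOKENS, TOKENS)
max_err = max(abs(a - b) for ra, rb in zip(ring, ref) for a, b in zip(ra, rb))
print(f"max |ring - full| = {max_err:.1e}")
```

The printed error should be at floating-point rounding level. Replacing the running-state update with a plain sum of per-round softmax outputs makes the error large, which is exactly the failure the merge step guards against.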

Memory Footprint

In this example:

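  • Single GPU (all 8 tokens): 8 × 2 = 16 values for each of Q, K and V (plus intermediate buffers such as the 8 × 8 score matrix)
  • Ring Attention, 4 GPUs: (8/4) × 2 = 4 values for each of Q, K and V per GPU, 16 in total across the ring

Each GPU therefore holds a quarter of the data (a 4× reduction versus a single GPU holding all 16 values per tensor), plus a buffer for the KV chunk in transit, and its score block in each round is only 2 × 2.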

This scales to real sequences: 1 million tokens on 8 GPUs means 125K tokens per GPU. Whether that fits on a given GPU still depends on the model's hidden size, number of layers and numeric precision.
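The memory arithmetic is simple enough to write down directly. A hypothetical helper, `per_gpu_elements`, assuming an evenly divisible sequence and a single head, and ignoring the KV receive buffer and all framework overhead; the hidden size of 4096 in the second call is an arbitrary illustrative value:

```python
def per_gpu_elements(seq_len, num_gpus, dim):
    """Rough element counts for Ring Attention sharding (single head).

    Simplifying assumptions: seq_len divides evenly, and the KV receive buffer
    and all framework overhead are ignored.
    """
    if seq_len % num_gpus != 0:
        raise ValueError("seq_len must be divisible by num_gpus")
    chunk = seq_len // num_gpus
    return {
        "tokens_per_gpu": chunk,
        "qkv_block": chunk * dim,                     # elements in one of Q, K, V per GPU
        "score_block": chunk * chunk,                 # one Q-block x KV-block score tile
        "full_scores_single_gpu": seq_len * seq_len,  # full n x n matrix, for comparison
    }


print(per_gpu_elements(8, 4, 2))
# {'tokens_per_gpu': 2, 'qkv_block': 4, 'score_block': 4, 'full_scores_single_gpu': 64}

big = per_gpu_elements(1_000_000, 8, 4096)      # 4096: hypothetical hidden size
print(big["tokens_per_gpu"], big["qkv_block"])  # 125000 512000000
```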