Worked Example: Ring Attention Step-by-Step
This section walks through a concrete example of Ring Attention with 4 GPUs and 8 tokens. You can visualise how KV chunks circulate and how each GPU accumulates attention.
Setup
Parameters:
- Total sequence: 8 tokens
- Number of GPUs: 4 (each GPU holds 2 tokens)
- Embedding dimension: 2
- 1 attention head (for simplicity)
Initial Distribution:
GPU 0: tokens [0, 1]
GPU 1: tokens [2, 3]
GPU 2: tokens [4, 5]
GPU 3: tokens [6, 7]
Embeddings:
Token 0: [1, 0]
Token 1: [0, 1]
Token 2: [1, 1]
Token 3: [1, 2]
Token 4: [2, 1]
Token 5: [2, 2]
Token 6: [3, 1]
Token 7: [3, 3]
Projection Matrices (identity for simplicity):
- W_Q = W_K = W_V = I (identity), so embedding = query = key = value
GPU Initial State:
GPU 0: Q[0:2] = [[1,0], [0,1]] KV[0:2] = (same)
GPU 1: Q[2:4] = [[1,1], [1,2]] KV[2:4] = (same)
GPU 2: Q[4:6] = [[2,1], [2,2]] KV[4:6] = (same)
GPU 3: Q[6:8] = [[3,1], [3,3]] KV[6:8] = (same)
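The setup above can be reproduced in a few lines of NumPy. This is only a sketch: the "GPUs" are simulated as list entries on one machine, and the names `X`, `NUM_GPUS`, `q_chunks`, `k_chunks` and `v_chunks` are illustrative choices, not part of any library.

```python
import numpy as np

# Embeddings for the 8 tokens, copied from the list above.
X = np.array([[1, 0], [0, 1], [1, 1], [1, 2],
              [2, 1], [2, 2], [3, 1], [3, 3]], dtype=float)

NUM_GPUS = 4  # hypothetical name; each "GPU" is just an index into a list

# Identity projections, so Q, K and V all equal the embeddings.
W = np.eye(2)
Q, K, V = X @ W, X @ W, X @ W

# One contiguous 2-token chunk per GPU.
q_chunks = np.split(Q, NUM_GPUS)
k_chunks = np.split(K, NUM_GPUS)
v_chunks = np.split(V, NUM_GPUS)

for rank in range(NUM_GPUS):
    print(f"GPU {rank}: Q = {q_chunks[rank].tolist()}")
```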
Round 0: Each GPU Computes with Its Own KV
GPU 0: Q[0:2] @ KV[0:2]
Q[0:2] = [[1, 0], [0, 1]]
K[0:2] = [[1, 0], [0, 1]]
V[0:2] = [[1, 0], [0, 1]]
Attention scores = Q @ K^T / √2:
Q[0] @ K^T = [1, 0] @ [[1, 0], [0, 1]]^T = [1, 0]
Q[1] @ K^T = [0, 1] @ [[1, 0], [0, 1]]^T = [0, 1]
Scores = [[1, 0], [0, 1]] / √2 ≈ [[0.707, 0], [0, 0.707]]
Softmax (per row):
Row 0: softmax([0.707, 0]) ≈ [0.670, 0.330]
Row 1: softmax([0, 0.707]) ≈ [0.330, 0.670]
Weights = [[0.670, 0.330], [0.330, 0.670]]
Output_0_round_0 = Weights @ V[0:2]
= [[0.670, 0.330], [0.330, 0.670]] @ [[1, 0], [0, 1]]
= [[0.670, 0.330], [0.330, 0.670]]
Communication: GPU 0 sends KV[0:2] to GPU 1, receives KV[6:8] from GPU 3.
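GPU 0's Round 0 computation can be checked with a short NumPy sketch. The helper name `block_attention` is hypothetical; it applies ordinary scaled softmax attention to a single KV chunk, with d = 2 as in this example.

```python
import numpy as np

def block_attention(q, k, v):
    """Softmax attention of q against one KV chunk only."""
    scores = q @ k.T / np.sqrt(q.shape[-1])        # scale by sqrt(d)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

# GPU 0 in Round 0: queries, keys and values are all tokens 0 and 1.
q0 = np.array([[1.0, 0.0], [0.0, 1.0]])
out = block_attention(q0, q0, q0)
print(np.round(out, 3))  # rows are roughly [0.670, 0.330] and [0.330, 0.670]
```

Because V is the identity here, the output equals the softmax weights themselves.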
GPU 1: Q[2:4] @ KV[2:4]
Q[2:4] = [[1, 1], [1, 2]]
K[2:4] = [[1, 1], [1, 2]]
Scores = Q @ K^T / √2:
[1, 1] @ [[1, 1], [1, 2]]^T = [2, 3] / √2 ≈ [1.414, 2.121]
[1, 2] @ [[1, 1], [1, 2]]^T = [3, 5] / √2 ≈ [2.121, 3.536]
Softmax:
Row 0: softmax([1.414, 2.121]) ≈ [0.330, 0.670]
Row 1: softmax([2.121, 3.536]) ≈ [0.196, 0.804]
Output_1_round_0 = [[0.330, 0.670], [0.196, 0.804]] @ V[2:4]
= [[0.330, 0.670], [0.196, 0.804]] @ [[1, 1], [1, 2]]
= [[1, 1.670], [1, 1.804]]
Communication: GPU 1 sends KV[2:4] to GPU 2, receives KV[0:2] from GPU 0.
GPU 2, GPU 3: Similar
(Skipping detailed calculations for brevity, but same pattern.)
After Round 0:
- Each GPU has computed attention with its own KV chunk
- KV chunks are circulating: GPU i sends to GPU (i+1) % 4
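The circulation rule implies a fixed schedule: in round r, GPU i holds the chunk that started on GPU (i - r) % 4. A minimal sketch of that schedule follows; `chunk_held` is a hypothetical helper name, and chunk c stands for tokens [2c, 2c+2).

```python
NUM_GPUS = 4

def chunk_held(rank, round_idx, num_gpus=NUM_GPUS):
    """Index of the KV chunk that GPU `rank` holds during `round_idx`."""
    return (rank - round_idx) % num_gpus

for round_idx in range(NUM_GPUS):
    held = [chunk_held(rank, round_idx) for rank in range(NUM_GPUS)]
    print(f"round {round_idx}: {held}")
# round 0: [0, 1, 2, 3]
# round 1: [3, 0, 1, 2]
# round 2: [2, 3, 0, 1]
# round 3: [1, 2, 3, 0]
```

Reading down a column gives the order in which one GPU sees the chunks; GPU 0, for instance, sees chunks 0, 3, 2, 1.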
Round 1: Each GPU Computes with Neighbour’s KV
Now GPU 0 holds KV[6:8] (received from GPU 3), so it computes Q[0:2] @ KV[6:8].
GPU 0: Q[0:2] @ KV[6:8]
Q[0:2] = [[1, 0], [0, 1]]
K[6:8] = [[3, 1], [3, 3]]
Scores = Q @ K^T / √2:
[1, 0] @ [[3, 1], [3, 3]]^T = [3, 3] / √2 ≈ [2.121, 2.121]
[0, 1] @ [[3, 1], [3, 3]]^T = [1, 3] / √2 ≈ [0.707, 2.121]
Softmax:
Row 0: softmax([2.121, 2.121]) ≈ [0.5, 0.5]
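Row 1: softmax([0.707, 2.121]) ≈ [0.196, 0.804]
Output_0_round_1 = [[0.5, 0.5], [0.196, 0.804]] @ V[6:8]
= [[0.5, 0.5], [0.196, 0.804]] @ [[3, 1], [3, 3]]
= [[3, 2], [3, 2.609]]
Accumulate: the two block outputs cannot simply be added, because each was normalised by its own softmax denominator. Weight each block by its row sums of exp(scores), written l, then renormalise:
l_round_0 = [e^0.707 + e^0, e^0 + e^0.707] ≈ [3.028, 3.028]
l_round_1 = [e^2.121 + e^2.121, e^0.707 + e^2.121] ≈ [16.684, 10.370]
Output_0 = (l_round_0 * Output_0_round_0 + l_round_1 * Output_0_round_1) / (l_round_0 + l_round_1), row by row
≈ [[2.64, 1.74], [2.40, 2.17]]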
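Per-block softmax outputs are each normalised by their own denominator, so combining them needs the online softmax rescaling rather than a plain sum. The sketch below shows one common way to do this, keeping a running row max, a normaliser and an unnormalised output per block; the function names `partial_attention` and `merge` are hypothetical.

```python
import numpy as np

def partial_attention(q, k, v):
    """Return (row max, row sum of exp, unnormalised output) for one KV chunk."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)                      # shifted for numerical stability
    return m, p.sum(axis=-1, keepdims=True), p @ v

def merge(a, b):
    """Combine two partial results as if their KV chunks were concatenated."""
    (m_a, l_a, o_a), (m_b, l_b, o_b) = a, b
    m = np.maximum(m_a, m_b)
    scale_a, scale_b = np.exp(m_a - m), np.exp(m_b - m)
    return m, l_a * scale_a + l_b * scale_b, o_a * scale_a + o_b * scale_b

q0 = np.array([[1.0, 0.0], [0.0, 1.0]])            # GPU 0's queries
kv_own = np.array([[1.0, 0.0], [0.0, 1.0]])        # tokens 0 and 1
kv_from_gpu3 = np.array([[3.0, 1.0], [3.0, 3.0]])  # tokens 6 and 7

m, l, o = merge(partial_attention(q0, kv_own, kv_own),
                partial_attention(q0, kv_from_gpu3, kv_from_gpu3))
merged = o / l                                      # normalise once, at the end
print(np.round(merged, 2))  # approximately [[2.64, 1.74], [2.40, 2.17]]
```

Dividing by the normaliser only once, after all chunks have been merged, avoids renormalising the output every round.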
GPU 1: Q[2:4] @ KV[0:2]
Q[2:4] = [[1, 1], [1, 2]]
K[0:2] = [[1, 0], [0, 1]]
Scores = Q @ K^T / √2:
[1, 1] @ [[1, 0], [0, 1]]^T = [1, 1] / √2 ≈ [0.707, 0.707]
[1, 2] @ [[1, 0], [0, 1]]^T = [1, 2] / √2 ≈ [0.707, 1.414]
Softmax:
Row 0: softmax([0.707, 0.707]) ≈ [0.5, 0.5]
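Row 1: softmax([0.707, 1.414]) ≈ [0.330, 0.670]
Output_1_round_1 = [[0.5, 0.5], [0.330, 0.670]] @ V[0:2]
= [[0.5, 0.5], [0.330, 0.670]]
Accumulate: as on GPU 0, weight each block by its row sums of exp(scores), l, rather than adding the outputs directly:
l_round_0 = [e^1.414 + e^2.121, e^2.121 + e^3.536] ≈ [12.455, 42.655]
l_round_1 = [e^0.707 + e^0.707, e^0.707 + e^1.414] ≈ [4.056, 6.141]
Output_1 = (l_round_0 * Output_1_round_0 + l_round_1 * Output_1_round_1) / (l_round_0 + l_round_1), row by row
≈ [[0.88, 1.38], [0.92, 1.66]]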
Round 2 & 3: Continue Circulation
KV chunks continue circulating. After Round 1:
- GPU 0 has received KV[4:6] (from GPU 3, originally from GPU 2)
- GPU 1 has received KV[6:8] (from GPU 0, originally from GPU 3)
- And so on…
In Rounds 2 and 3, each GPU computes with the remaining KV chunks it hasn’t yet processed.
After all 4 rounds, every GPU has computed:
GPU 0: Q[0:2] @ KV[0:2], Q[0:2] @ KV[6:8], Q[0:2] @ KV[4:6], Q[0:2] @ KV[2:4]
GPU 1: Q[2:4] @ KV[2:4], Q[2:4] @ KV[0:2], Q[2:4] @ KV[6:8], Q[2:4] @ KV[4:6]
GPU 2: Q[4:6] @ KV[4:6], Q[4:6] @ KV[2:4], Q[4:6] @ KV[0:2], Q[4:6] @ KV[6:8]
GPU 3: Q[6:8] @ KV[6:8], Q[6:8] @ KV[4:6], Q[6:8] @ KV[2:4], Q[6:8] @ KV[0:2]
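Once the per-chunk results are merged with the softmax rescaling, this is equivalent to:
Full Attention = softmax(Q @ K^T / √d) @ V
where Q, K and V span all 8 tokens (full sequence) and d = 2 is the embedding dimension.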
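The equivalence is easier to see when the softmax is written as a ratio of sums and the sum over keys is split by chunk. In the sketch below the notation is an assumption introduced for illustration: \(\mathcal{C}_c\) is the set of token indices in chunk \(c\), \(P\) is the number of GPUs, and \(o_i\) is the output row for query \(i\).

```latex
\[
o_i \;=\;
\frac{\displaystyle\sum_{c=0}^{P-1} \sum_{j \in \mathcal{C}_c} e^{s_{ij}}\, v_j}
     {\displaystyle\sum_{c=0}^{P-1} \sum_{j \in \mathcal{C}_c} e^{s_{ij}}},
\qquad
s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d}}
\]
```

Each round contributes one inner sum to the numerator and one to the denominator, so the order in which chunks arrive does not change the result.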
Key Observations
- No replication: No KV chunk is stored on more than one GPU at a time (after initial placement). KV chunks circulate, not duplicate.
- Balanced communication: Each GPU sends and receives the same amount of data each round.
- Full attention achieved: Even though Q and KV are split across GPUs, the final output is equivalent to full attention on a single GPU.
- Deterministic ordering: GPU i always receives KV from GPU (i-1) in a predictable pattern. No random access.
- Scalability: This pattern works for any number of GPUs P and any sequence length n (where n is divisible by P).
Numerical Verification
You could verify that the final accumulated output from Ring Attention matches what you’d get from computing full attention on all 8 tokens simultaneously on a single GPU. The online softmax trick ensures numerical correctness across blocks.
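One way to run that check is to simulate the ring on a single machine. The sketch below is an illustration only: `ring_attention` and `full_attention` are hypothetical names, the GPUs are list indices, and the hand-off between neighbours is a list rotation rather than real communication.

```python
import numpy as np

X = np.array([[1, 0], [0, 1], [1, 1], [1, 2],
              [2, 1], [2, 2], [3, 1], [3, 3]], dtype=float)
P = 4             # number of simulated GPUs
d = X.shape[1]    # embedding dimension

def full_attention(q, k, v):
    """Reference: softmax attention over the whole sequence at once."""
    s = q @ k.T / np.sqrt(d)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def ring_attention(x, num_gpus):
    q_chunks = np.split(x, num_gpus)
    kv_chunks = np.split(x, num_gpus)   # K = V = x in this example
    # Running state per GPU: row max, normaliser, unnormalised output.
    m = [np.full((len(q), 1), -np.inf) for q in q_chunks]
    l = [np.zeros((len(q), 1)) for q in q_chunks]
    o = [np.zeros_like(q) for q in q_chunks]
    for _ in range(num_gpus):
        for rank, q in enumerate(q_chunks):
            kv = kv_chunks[rank]
            s = q @ kv.T / np.sqrt(d)
            m_new = np.maximum(m[rank], s.max(axis=-1, keepdims=True))
            p = np.exp(s - m_new)
            scale = np.exp(m[rank] - m_new)   # rescale what was accumulated so far
            l[rank] = l[rank] * scale + p.sum(axis=-1, keepdims=True)
            o[rank] = o[rank] * scale + p @ kv
            m[rank] = m_new
        # Ring step: GPU i receives the chunk held by GPU (i - 1) % num_gpus.
        kv_chunks = [kv_chunks[(rank - 1) % num_gpus] for rank in range(num_gpus)]
    return np.vstack([o_r / l_r for o_r, l_r in zip(o, l)])

ring_out = ring_attention(X, P)
reference = full_attention(X, X, X)
print(np.allclose(ring_out, reference))  # True
```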
Memory Footprint
In this example:
- Single GPU (all 8 tokens): 8 × 2 = 16 values per tensor (Q, K, or V), plus intermediate buffers
- Ring Attention, 4 GPUs: (8/4) × 2 = 4 values per tensor per GPU = 16 values total
Memory is distributed across 4 GPUs: 4 values per GPU, a 4× reduction versus a single GPU holding all 16.
This scales to real sequences: 1 million tokens on 8 GPUs is 125K tokens per GPU. Whether that fits on consumer hardware depends on the model's hidden size, layer count, and numeric precision.
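The arithmetic above generalises to a one-line calculation. The helper `per_gpu_footprint` is a hypothetical name, and it counts only the stored values of a single tensor, not intermediate buffers or the receive buffer for the incoming KV chunk.

```python
def per_gpu_footprint(seq_len, num_gpus, dim):
    """Tokens and stored values (for one tensor) held by each GPU."""
    if seq_len % num_gpus != 0:
        raise ValueError("seq_len must be divisible by num_gpus")
    tokens = seq_len // num_gpus
    return tokens, tokens * dim

print(per_gpu_footprint(8, 4, 2))             # (2, 4)
print(per_gpu_footprint(1_000_000, 8, 2)[0])  # 125000
```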