Section 03

The Idea: Ring Topology and Pipelined Communication

Ring Attention with Blockwise Transformers for Near-Infinite Context (2023)


Ring Attention’s genius is surprisingly simple: arrange GPUs in a ring and pass key/value (KV) chunks from device to device, overlapping each transfer with attention computation on the chunk each device already holds.

Core Concept: The Ring Topology

Imagine P GPUs arranged in a ring:

        GPU 0
       /     \
    GPU 3   GPU 1
       \     /
        GPU 2

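The ring is a logical communication pattern: each GPU i sends only to GPU (i+1) mod P and receives only from GPU (i-1) mod P. A chunk bound for GPU (i+2) gets there through the intermediate device, even if the hardware could connect the two directly.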

But this simplicity is the strength: each GPU only needs to synchronise with two neighbours, not all P devices.
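The wrap-around is just modular arithmetic. A minimal sketch, using a hypothetical helper `ring_neighbours` (not from the paper), of how each GPU finds its two neighbours:

```python
def ring_neighbours(i: int, P: int) -> tuple[int, int]:
    """Return (previous, next) rank for GPU i in a logical ring of P GPUs.

    The modulo makes the ring wrap around: GPU 0's previous neighbour is GPU P-1.
    """
    return (i - 1) % P, (i + 1) % P


if __name__ == "__main__":
    P = 4
    for i in range(P):
        prev, nxt = ring_neighbours(i, P)
        print(f"GPU {i}: receives from GPU {prev}, sends to GPU {nxt}")
```

For P = 4 this reproduces the diagram: 0 → 1 → 2 → 3 → 0.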

The Key Insight: Blockwise Attention

Instead of computing Q @ K^T / √d_head for all n tokens at once, compute it in blocks:

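  • Block (0, 0): Q[0:n/P] @ K[0:n/P]^T
  • Block (1, 0): Q[n/P:2n/P] @ K[0:n/P]^T
  • Block (1, 1): Q[n/P:2n/P] @ K[n/P:2n/P]^T
  • And so on: block (i, j) pairs query chunk i with key chunk j, for P × P blocks in total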

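Each block is small enough to fit on one GPU. For each block you compute the scores Q_block @ K_block^T / √d_head, exponentiate them, use the result to weight V_block, and merge the partial output into a running result.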

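Why blockwise works: softmax attention decomposes over key/value blocks. The softmax denominator is a sum over all keys, so each block’s partial numerator and denominator can be computed separately and merged exactly, provided each block also records its maximum score so the partial sums can be rescaled onto a common maximum before merging.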
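A minimal pure-Python sketch of the tiling, on toy sizes and with hypothetical helpers `score_block` and `blockwise_scores`: it builds the score matrix Q @ K^T / √d one (i, j) block at a time and checks it against computing all scores at once. It covers only the scores; merging softmax outputs across blocks needs the max-and-sum bookkeeping of online softmax.

```python
import math
import random


def score_block(Qi, Kj, scale):
    """Scores for one (i, j) block: Qi @ Kj^T * scale, as nested lists."""
    return [[scale * sum(a * b for a, b in zip(q, k)) for k in Kj] for q in Qi]


def blockwise_scores(Q, K, P):
    """Assemble the full n x n score matrix from P x P blocks.

    Block (i, j) pairs query chunk i with key chunk j; assumes n is divisible by P.
    """
    n, d = len(Q), len(Q[0])
    c = n // P
    scale = 1.0 / math.sqrt(d)
    S = [[0.0] * n for _ in range(n)]
    for i in range(P):
        for j in range(P):
            block = score_block(Q[i * c:(i + 1) * c], K[j * c:(j + 1) * c], scale)
            for r, row in enumerate(block):
                S[i * c + r][j * c:(j + 1) * c] = row  # paste block (i, j) into place
    return S


if __name__ == "__main__":
    random.seed(0)
    n, d, P = 8, 4, 4
    Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    direct = score_block(Q, K, 1.0 / math.sqrt(d))  # the whole matrix as one "block"
    tiled = blockwise_scores(Q, K, P)
    print("blockwise == direct:", tiled == direct)
```

Each score depends on one query row and one key row only, so the blocks are independent of one another; that independence is what lets different GPUs own different blocks.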

How Ring Attention Works (Algorithm)

Setup:

  • P GPUs, GPU i holds:
    • Query chunk Q[i]: tokens i×(n/P) up to, but not including, (i+1)×(n/P)
    • KV chunk KV[i]: keys and values for the same token range
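Under the contiguous split above, and assuming n is divisible by P (which the setup implies), the token range a GPU owns is a half-open interval. `chunk_bounds` is a hypothetical helper for illustration:

```python
def chunk_bounds(i: int, n: int, P: int) -> range:
    """Token indices owned by GPU i when n tokens are split evenly over P GPUs.

    Half-open interval [i * n/P, (i + 1) * n/P); this sketch requires n % P == 0.
    """
    if n % P != 0:
        raise ValueError("this sketch assumes n is divisible by P")
    size = n // P
    return range(i * size, (i + 1) * size)


if __name__ == "__main__":
    n, P = 16, 4
    for i in range(P):
        r = chunk_bounds(i, n, P)
        print(f"GPU {i}: tokens {r.start}..{r.stop - 1}")
```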

Algorithm (pseudocode):

for r in range(P):
    # In round r, GPU i holds (and attends to) the KV chunk that started on GPU (i - r) % P
    kv_chunk_idx = (i - r) % P

    # Parallel tasks (compute + communication):
    Task 1 (Compute): GPU i computes scores Q[i] @ K[kv_chunk_idx]^T / √d
                      Weights V[kv_chunk_idx] by the exponentiated scores (online softmax)
                      Accumulates into its running output, maximum and normaliser

    Task 2 (Communication): Simultaneously (not needed in the last round),
                            GPU i sends KV[kv_chunk_idx] to GPU (i + 1) % P
                            GPU i receives KV[(kv_chunk_idx - 1) % P] from GPU (i - 1) % P

    Sync: each GPU waits until this round's send and receive complete before round r + 1

Output: Each GPU i holds the final attention output for its own queries Q[i], computed over all n keys
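The algorithm can be checked end to end without any GPUs. The sketch below is a single-process simulation, not the paper’s implementation: list indices stand in for GPUs, rotating a Python list stands in for the send/receive step, and the compute step uses the online-softmax update. It assumes non-causal, single-head attention and n divisible by P; all function names are hypothetical.

```python
import math
import random


def attend_block(q, K, V, m, l, o, scale):
    """Online-softmax update of one query row's running state with one KV block.

    m: running max score, l: running sum of exp(score - m),
    o: running unnormalised output (sum of exp(score - m) * v).
    """
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m_new = max(m, max(scores))
    alpha = math.exp(m - m_new)  # rescales the old state to the new max (0.0 at the start)
    weights = [math.exp(s - m_new) for s in scores]
    l_new = l * alpha + sum(weights)
    o_new = [oj * alpha + sum(w * v[j] for w, v in zip(weights, V))
             for j, oj in enumerate(o)]
    return m_new, l_new, o_new


def ring_attention_sim(Q, K, V, P):
    """Simulate non-causal, single-head Ring Attention on P 'GPUs' in one process."""
    n, d = len(Q), len(Q[0])
    c = n // P  # tokens per GPU; this sketch assumes n % P == 0
    scale = 1.0 / math.sqrt(d)

    def chunk(X, i):
        return X[i * c:(i + 1) * c]

    # Per-GPU state: one (m, l, o) triple for each local query row.
    state = [[(-math.inf, 0.0, [0.0] * len(V[0])) for _ in range(c)] for _ in range(P)]
    # held[i] is the (K, V) chunk currently resident on GPU i; it starts as GPU i's own.
    held = [(chunk(K, i), chunk(V, i)) for i in range(P)]
    for _ in range(P):
        # Compute step: on real hardware every GPU does this concurrently.
        for i in range(P):
            Kb, Vb = held[i]
            state[i] = [attend_block(q, Kb, Vb, *st, scale)
                        for q, st in zip(chunk(Q, i), state[i])]
        # Communication step: GPU i receives what GPU (i - 1) % P held this round.
        held = [held[(i - 1) % P] for i in range(P)]
    # Normalise o / l; GPUs are visited in rank order, so rows come out in token order.
    return [[x / l for x in o] for gpu in state for (_, l, o) in gpu]


def full_attention(Q, K, V):
    """Reference: ordinary softmax attention over all keys at once."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


if __name__ == "__main__":
    random.seed(0)
    n, d, P = 8, 4, 4
    Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
    ring = ring_attention_sim(Q, K, V, P)
    ref = full_attention(Q, K, V)
    err = max(abs(a - b) for ra, rb in zip(ring, ref) for a, b in zip(ra, rb))
    print(f"max abs difference vs. full attention: {err:.2e}")
```

On real hardware the inner loop over i runs on P devices at once, and the list rotation becomes a point-to-point exchange overlapped with compute (for example `jax.lax.ppermute` in JAX, or `torch.distributed.isend`/`irecv` in PyTorch). The final rotation here simply returns every chunk to its owner and is harmless.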

The magic: Task 1 (compute) and Task 2 (communication) happen in parallel. While GPU i is computing, the next KV chunk is already in flight. As long as computing on a chunk takes at least as long as transferring one, the next chunk has arrived by the time the computation finishes.

Visualising the Flow (Round-by-Round)

Initial state (round 0):

GPU 0 has: Q[0], KV[0]
GPU 1 has: Q[1], KV[1]
GPU 2 has: Q[2], KV[2]
GPU 3 has: Q[3], KV[3]

Round 0:

GPU 0: Compute Q[0] @ KV[0] | Send KV[0]→GPU 1, Recv KV[3]←GPU 3
GPU 1: Compute Q[1] @ KV[1] | Send KV[1]→GPU 2, Recv KV[0]←GPU 0
GPU 2: Compute Q[2] @ KV[2] | Send KV[2]→GPU 3, Recv KV[1]←GPU 1
GPU 3: Compute Q[3] @ KV[3] | Send KV[3]→GPU 0, Recv KV[2]←GPU 2

Round 1:

Each KV chunk has now moved one position around the ring (the GPUs stay put). KV chunks continue circulating.

GPU 0: Compute Q[0] @ KV[3] | Send KV[3]→GPU 1, Recv KV[2]←GPU 3
GPU 1: Compute Q[1] @ KV[0] | Send KV[0]→GPU 2, Recv KV[3]←GPU 0
GPU 2: Compute Q[2] @ KV[1] | Send KV[1]→GPU 3, Recv KV[0]←GPU 1
GPU 3: Compute Q[3] @ KV[2] | Send KV[2]→GPU 0, Recv KV[1]←GPU 2

After P rounds: each GPU i has computed Q[i] against all P KV chunks, so every query has attended to every key. Full attention achieved.
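The round-by-round tables follow directly from the index arithmetic. This sketch (hypothetical helpers `ring_schedule` and `chunks_seen`) regenerates them for any P and confirms that each GPU computes on every KV chunk exactly once:

```python
def ring_schedule(P: int, rounds: int):
    """Yield (round, lines) describing what each GPU does in that round.

    In round r, GPU i computes on KV[(i - r) % P], forwards that chunk to
    GPU (i + 1) % P and receives KV[(i - r - 1) % P] from GPU (i - 1) % P.
    """
    for r in range(rounds):
        lines = []
        for i in range(P):
            held, incoming = (i - r) % P, (i - r - 1) % P
            nxt, prev = (i + 1) % P, (i - 1) % P
            lines.append(f"GPU {i}: Compute Q[{i}] @ KV[{held}] | "
                         f"Send KV[{held}]→GPU {nxt}, Recv KV[{incoming}]←GPU {prev}")
        yield r, lines


def chunks_seen(P: int) -> list:
    """For each GPU, the sorted KV chunk indices it computes on over P rounds."""
    return [sorted((i - r) % P for r in range(P)) for i in range(P)]


if __name__ == "__main__":
    for r, lines in ring_schedule(P=4, rounds=2):
        print(f"Round {r}:")
        print("\n".join(lines))
    # Every GPU sees every chunk exactly once, so attention covers all keys.
    assert all(seen == list(range(4)) for seen in chunks_seen(4))
```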

Compute-Communication Overlap (The Secret Sauce)

This is why Ring Attention is faster than naive communication patterns:

Without pipelining (synchronous):

GPU 0: Compute (1 second) → Wait for communication (2 seconds) → Compute (1 second) → ...
       Utilisation = 1/3 = 33%

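With pipelining (Ring Attention), assuming each transfer takes no longer than one compute step:

GPU 0: [Compute 1s] [Compute 1s + Comm in parallel] [Compute 1s + Comm in parallel]
       Utilisation = ~100% (compute fully hides communication)

If compute time per round is at least the communication time (which we design for), communication latency is hidden entirely and the network is effectively “free.” If communication is slower, overlap helps but cannot hide it: with the 2-second transfers from the synchronous example, each round still lasts 2 seconds, so utilisation rises only from 33% to 50%.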
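A toy steady-state model makes the utilisation numbers explicit. It is a simplification (it ignores start-up, kernel contention and bandwidth sharing), and `utilisation` is a hypothetical helper:

```python
def utilisation(compute_s: float, comm_s: float, pipelined: bool) -> float:
    """Fraction of wall-clock time spent computing, in a per-round steady-state model.

    Synchronous: each round costs compute + comm.
    Pipelined: compute and comm overlap, so a round costs max(compute, comm).
    """
    round_time = max(compute_s, comm_s) if pipelined else compute_s + comm_s
    return compute_s / round_time


if __name__ == "__main__":
    for comm in (2.0, 1.0, 0.5):
        print(f"comm={comm}s  sync={utilisation(1.0, comm, False):.0%}  "
              f"pipelined={utilisation(1.0, comm, True):.0%}")
```

With 1 s of compute per round, the synchronous case gives 33%, 50% and 67% for 2 s, 1 s and 0.5 s of communication, while the pipelined case gives 50%, 100% and 100%: overlap fully hides communication only once it is no slower than compute.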

Blockwise Computation: Online Softmax

To make blockwise attention numerically correct, use online softmax (log-sum-exp trick):

Standard softmax (unstable):

softmax(x) = exp(x) / sum(exp(x))

Problem: exp(x) overflows for large scores, and if you apply softmax to block 1, then block 2, each result is normalised only by its own block’s sum, so they don’t combine correctly.
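Both problems are easy to reproduce with plain Python floats. The sketch below (helper names hypothetical) shows `math.exp` overflowing on large scores, and shows that softmaxes taken over two blocks separately do not form a distribution over the whole row:

```python
import math


def naive_softmax(xs):
    """Textbook softmax exp(x) / sum(exp(x)); math.exp raises OverflowError above ~709."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(xs):
    """Same result, but subtracting the max first keeps every exponent <= 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.5, 3.0]
    # Softmax each half separately, then glue: the weights now sum to about 2, not 1.
    glued = stable_softmax(scores[:2]) + stable_softmax(scores[2:])
    print("sum of per-block softmaxes:", sum(glued))
    print("sum of whole-row softmax:  ", sum(stable_softmax(scores)))
    try:
        naive_softmax([1000.0, 999.0])
    except OverflowError as exc:
        print("naive softmax overflowed:", exc)
    print("stable softmax is fine:", stable_softmax([1000.0, 999.0]))
```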

Online softmax:

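For each block b (per query row; the max and sum run over that block’s keys):
  m_b = max(Q @ K[b]^T / √d)  (per-block maximum)
  l_b = sum(exp(Q @ K[b]^T / √d - m_b))  (per-block normalization)
  o_b = exp(Q @ K[b]^T / √d - m_b) @ V[b]  (per-block unnormalised output)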

Then, combine blocks:
  m_final = max(m_1, m_2, ...)
  l_final = sum(l_b × exp(m_b - m_final))
  o_final = sum(o_b × exp(m_b - m_final)) / l_final

This is numerically stable and gives the same result as computing all scores at once and then applying softmax (exactly in real arithmetic, up to rounding in floating point).
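A small sketch of the combine step for a single query row, using made-up scores and values: `block_stats` computes (m_b, l_b, o_b) for each block and `combine` applies the three merge formulas. Both names are hypothetical; the check is against a plain softmax over all scores at once.

```python
import math


def block_stats(scores, values):
    """(m_b, l_b, o_b) for one query row and one KV block.

    scores: the block's scaled scores (Q @ K[b]^T / sqrt(d)) for this query;
    values: the block's rows of V.
    """
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    o = [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]
    return m, sum(w), o


def combine(stats):
    """Merge per-block (m_b, l_b, o_b) into o_final by rescaling to the global max."""
    m_final = max(m for m, _, _ in stats)
    l_final = sum(l * math.exp(m - m_final) for m, l, _ in stats)
    dim = len(stats[0][2])
    o_sum = [sum(o[j] * math.exp(m - m_final) for m, _, o in stats) for j in range(dim)]
    return [x / l_final for x in o_sum]


def direct_attention_row(scores, values):
    """Reference: softmax over all scores at once, then weight the values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.5, 3.0, -1.0, 0.0]  # made-up scores for one query
    values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]
    blocks = [block_stats(scores[k:k + 2], values[k:k + 2]) for k in range(0, 6, 2)]
    print("merged:", combine(blocks))
    print("direct:", direct_attention_row(scores, values))
```

Ring Attention applies the same rescaling incrementally, folding each arriving block into a running (m, l, o) rather than storing every block’s statistics until the end.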

Memory Scaling

With P GPUs:

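  • Each GPU stores Q chunk: (n/P) × d floats
  • Each GPU stores the KV chunk it is working on: 2 × (n/P) × d floats, plus a same-sized receive buffer for the incoming chunk while communication overlaps compute
  • Total per GPU: O((n/P) × d), provided the local attention is itself computed in tiles rather than materialising the full (n/P) × (n/P) score matrix

Compare to holding the whole sequence on one GPU: O(n × d).
Scaling factor: P× reduction in per-GPU memory.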

For 1 million tokens with 8 GPUs: 1M / 8 = 125K tokens per GPU. Feasible.
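A rough per-GPU estimate for the 1M-token, 8-GPU case. The assumptions are illustrative, not figures from the paper: d = 4096 (the value used in the communication estimate), 2 bytes per value, one attention layer, a second K+V buffer for the incoming chunk, and no materialised score matrix (the small per-row max and sum statistics are ignored). `per_gpu_activation_bytes` is a hypothetical helper:

```python
def per_gpu_activation_bytes(n: int, P: int, d: int, bytes_per_value: int = 2,
                             double_buffer: bool = True) -> dict:
    """Rough per-GPU bytes for one attention layer's Q, K, V and output chunks.

    Illustrative assumptions: every tensor stored at bytes_per_value bytes,
    K and V each (n/P) x d, an extra K+V receive buffer when double_buffer is True,
    and scores computed in tiles, so no (n/P) x (n/P) matrix is kept.
    """
    chunk = (n // P) * d * bytes_per_value
    parts = {
        "Q": chunk,
        "K+V (current)": 2 * chunk,
        "K+V (incoming buffer)": 2 * chunk if double_buffer else 0,
        "output accumulator": chunk,
    }
    parts["total"] = sum(parts.values())
    return parts


if __name__ == "__main__":
    for name, size in per_gpu_activation_bytes(n=1_000_000, P=8, d=4096).items():
        print(f"{name:>22}: {size / 1e9:.2f} GB")
```

Under these assumptions each chunk-sized tensor is about 1.02 GB and the total is about 6.1 GB per layer; without the split, each of these tensors would be 8 times larger.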

Communication Cost

Each KV chunk travels once around the ring, one hop per round. Each hop involves:

  • Sending 2 × (n/P) × d values (keys and values) to the next GPU
  • Receiving 2 × (n/P) × d values from the previous GPU

Total communication per round: O((n/P) × d) per GPU.
Over P rounds: O(n × d) total per GPU.

For 1M tokens with 8 GPUs, d = 4096, 2 bytes per value (bf16/fp16):

Comm per round = 2 (K and V) × (1M/8) × 4096 × 2 bytes ≈ 2 GB per GPU
Whole pass = 7 hops × 2 GB ≈ 14 GB sent per GPU (the last round needs no send)

Network capable of 200 GB/s: 2 GB / 200 GB/s ≈ 0.01 seconds per round
Compute time for one round: ~1 second (rough estimate)

Per round, communication takes about 1% as long as compute, so the overlap hides it completely.
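The same arithmetic as a small calculator, counting keys and values separately and only the P - 1 hops actually needed. The bandwidth and per-round compute time are inputs taken from the text (200 GB/s, ~1 s), not measurements, and `ring_comm_estimate` is a hypothetical helper:

```python
def ring_comm_estimate(n: int, P: int, d: int, bandwidth_gb_s: float,
                       compute_s_per_round: float, bytes_per_value: int = 2) -> dict:
    """Back-of-the-envelope Ring Attention traffic per GPU for one attention layer.

    Assumptions: K and V are each (n/P) x d at bytes_per_value bytes; P - 1 hops
    are needed (the last round sends nothing); the link sustains bandwidth_gb_s;
    compute_s_per_round is a user-supplied estimate.
    """
    per_round_bytes = 2 * (n // P) * d * bytes_per_value  # one K chunk + one V chunk
    per_round_s = per_round_bytes / (bandwidth_gb_s * 1e9)
    return {
        "per_round_GB": per_round_bytes / 1e9,
        "total_sent_GB": (P - 1) * per_round_bytes / 1e9,
        "per_round_comm_s": per_round_s,
        "comm_to_compute_ratio": per_round_s / compute_s_per_round,
        "hidden_by_overlap": per_round_s <= compute_s_per_round,
    }


if __name__ == "__main__":
    est = ring_comm_estimate(n=1_000_000, P=8, d=4096,
                             bandwidth_gb_s=200.0, compute_s_per_round=1.0)
    for key, value in est.items():
        print(f"{key}: {value}")
```

For these inputs it reports about 2.05 GB and 0.010 s per round, about 14.3 GB sent per GPU in total, and a communication-to-compute ratio of roughly 1%.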

The Indian Analogy

Imagine a relay race with 8 runners around a circular track. Each runner holds a piece of a puzzle.

Naive approach: All runners stop and gather in the center, compare puzzle pieces, then move on.

  • Slow: everyone waits for everyone else
  • Wasteful: runners stand idle between comparisons

Ring Attention approach: Runners keep running around the track. As they pass each other, they exchange puzzle pieces (hand-off is quick). Each runner, while running, pieces together the puzzle with the pieces they’ve collected.

  • Fast: no stop-and-wait
  • Efficient: handoff (communication) is hidden by running (compute)
  • After 8 hand-offs, every runner has seen every piece and assembled the full puzzle

This is compute-communication overlap.

Comparison to Alternatives

Mistral SWA: Each layer attends only within a fixed local window; long-range information reaches a token only indirectly, through stacked layers

Ring Attention: True full attention, with maximum sequence length growing linearly with the number of devices

Sparse Attention: Some attention patterns are ignored (lossy)

Ring Attention: Full attention, exact up to floating-point rounding

Gradient Checkpointing: Saves activation memory only during training (by recomputing activations in the backward pass); doesn’t help inference

Ring Attention: Helps both training and inference

The trade-off: Ring Attention requires multiple GPUs and adds implementation complexity. But among the approaches compared here, it is the only one that distributes both memory and computation across devices while keeping full attention.