5. Worked example — routing a batch of tokens through 4 experts
🔴 Advanced undergrad. We trace a batch of 4 tokens through a 4-expert MoE layer with k=2, then compute the auxiliary balancing loss.
Setup
4 experts, 2-dimensional token representations, k = 2.
The four tokens from the sentence “chai bahut garam hai” (“chai is very hot”):
x₁ = [1.0, 0.2] ("chai" — a concrete noun, drink)
x₂ = [0.3, 0.8] ("bahut" — an adverb, degree)
x₃ = [0.1, 0.5] ("garam" — an adjective, temperature)
x₄ = [0.6, 0.1] ("hai" — a verb, copula)
Gating weight matrix W_g (2 dimensions → 4 expert scores):
W_g = [[ 1.0, 0.5, -0.5, 0.2], ← weights for dim 1
[-0.2, 0.8, 1.0, -0.3]] ← weights for dim 2
shape: (2 × 4)
Expert networks (each is a simple linear map, for illustration):
Expert 1 (E₁): specialises in nouns/concrete objects
W_E1 = [[1.2, 0.0], [0.0, 0.5]] → E₁(x) = x · W_E1
Expert 2 (E₂): specialises in degree/quantifier words
W_E2 = [[0.3, 0.0], [0.0, 1.4]] → E₂(x) = x · W_E2
Expert 3 (E₃): specialises in descriptive/adjective words
W_E3 = [[0.2, 0.8], [0.9, 0.1]] → E₃(x) = x · W_E3
Expert 4 (E₄): specialises in function words/verbs
W_E4 = [[0.7, 0.3], [0.1, 0.6]] → E₄(x) = x · W_E4
Step 1: Compute raw gating logits for all tokens
h(xₜ) = xₜ · W_g for each token:
Token 1 "chai" x₁ = [1.0, 0.2]:
h(x₁) = [1.0×1.0 + 0.2×(−0.2), 1.0×0.5 + 0.2×0.8, 1.0×(−0.5) + 0.2×1.0, 1.0×0.2 + 0.2×(−0.3)]
= [1.0−0.04, 0.5+0.16, −0.5+0.20, 0.2−0.06]
= [0.96, 0.66, −0.30, 0.14]
Token 2 "bahut" x₂ = [0.3, 0.8]:
h(x₂) = [0.3×1.0 + 0.8×(−0.2), 0.3×0.5 + 0.8×0.8, 0.3×(−0.5) + 0.8×1.0, 0.3×0.2 + 0.8×(−0.3)]
= [0.30−0.16, 0.15+0.64, −0.15+0.80, 0.06−0.24]
= [0.14, 0.79, 0.65, −0.18]
Token 3 "garam" x₃ = [0.1, 0.5]:
h(x₃) = [0.1×1.0 + 0.5×(−0.2), 0.1×0.5 + 0.5×0.8, 0.1×(−0.5) + 0.5×1.0, 0.1×0.2 + 0.5×(−0.3)]
= [0.10−0.10, 0.05+0.40, −0.05+0.50, 0.02−0.15]
= [0.00, 0.45, 0.45, −0.13]
Token 4 "hai" x₄ = [0.6, 0.1]:
h(x₄) = [0.6×1.0 + 0.1×(−0.2), 0.6×0.5 + 0.1×0.8, 0.6×(−0.5) + 0.1×1.0, 0.6×0.2 + 0.1×(−0.3)]
= [0.60−0.02, 0.30+0.08, −0.30+0.10, 0.12−0.03]
= [0.58, 0.38, −0.20, 0.09]
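The per-token arithmetic above is a row-vector-by-matrix product, h(xₜ) = xₜ · W_g. Here is a minimal sketch in plain Python, with no NumPy so it stays dependency-free. The helper name `gating_logits` is hypothetical, and the numbers are copied from the setup:

```python
# Token vectors and gating matrix W_g (2 x 4), copied from the setup.
X = {
    "chai": [1.0, 0.2],
    "bahut": [0.3, 0.8],
    "garam": [0.1, 0.5],
    "hai": [0.6, 0.1],
}
W_g = [
    [1.0, 0.5, -0.5, 0.2],   # weights for dim 1
    [-0.2, 0.8, 1.0, -0.3],  # weights for dim 2
]


def gating_logits(x, W):
    """Row vector times matrix: h[j] = sum over d of x[d] * W[d][j]."""
    return [sum(x[d] * W[d][j] for d in range(len(x))) for j in range(len(W[0]))]


logits = {word: gating_logits(x, W_g) for word, x in X.items()}
for word, h in logits.items():
    print(f"{word:6s}", [f"{v:.2f}" for v in h])
```

Each printed row should match the hand-computed logits to two decimals.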
Step 2: Apply Top-2 selection (k=2)
Keep the 2 highest scores per token, set others to −∞:
Token 1 "chai": scores [0.96, 0.66, −0.30, 0.14]
Top 2: Expert 1 (0.96) ✓, Expert 2 (0.66) ✓
After TopK: [0.96, 0.66, −∞, −∞] → Experts 3, 4 skipped
Token 2 "bahut": scores [0.14, 0.79, 0.65, −0.18]
Top 2: Expert 2 (0.79) ✓, Expert 3 (0.65) ✓
After TopK: [−∞, 0.79, 0.65, −∞] → Experts 1, 4 skipped
Token 3 "garam": scores [0.00, 0.45, 0.45, −0.13]
Top 2: Expert 2 (0.45) ✓, Expert 3 (0.45) ✓ (tied, but both tied experts make the top 2, so no tie-break is needed)
After TopK: [−∞, 0.45, 0.45, −∞] → Experts 1, 4 skipped
Token 4 "hai": scores [0.58, 0.38, −0.20, 0.09]
Top 2: Expert 1 (0.58) ✓, Expert 2 (0.38) ✓
After TopK: [0.58, 0.38, −∞, −∞] → Experts 3, 4 skipped
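Top-k selection amounts to sorting the logits, keeping the k largest, and masking the rest with −∞. In this sketch, `top_k_mask` is a hypothetical helper. Ties go to the lower expert index because Python's `sorted` is stable. That tie rule is an assumption, since real MoE frameworks may break ties differently.

```python
import math


def top_k_mask(logits, k):
    """Keep the k largest logits and replace the rest with -inf.

    sorted() is stable, so tied logits keep their original order and the
    lower expert index wins a tie (an assumption made for this sketch).
    """
    order = sorted(range(len(logits)), key=lambda j: -logits[j])
    keep = set(order[:k])
    return [v if j in keep else -math.inf for j, v in enumerate(logits)]


# Raw gating logits from Step 1.
logits = {
    "chai": [0.96, 0.66, -0.30, 0.14],
    "bahut": [0.14, 0.79, 0.65, -0.18],
    "garam": [0.00, 0.45, 0.45, -0.13],
    "hai": [0.58, 0.38, -0.20, 0.09],
}
masked = {w: top_k_mask(h, k=2) for w, h in logits.items()}
# 1-indexed expert numbers that survived the mask, as in the text.
selected = {w: [j + 1 for j, v in enumerate(m) if v != -math.inf] for w, m in masked.items()}
print(selected)  # {'chai': [1, 2], 'bahut': [2, 3], 'garam': [2, 3], 'hai': [1, 2]}
```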
Step 3: Compute gating weights (softmax over top-2)
Token 1 "chai": softmax([0.96, 0.66])
exp(0.96)=2.612, exp(0.66)=1.935, sum=4.547
G(x₁) = [0.574, 0.426, 0.000, 0.000]
Token 2 "bahut": softmax([0.79, 0.65])
exp(0.79)=2.203, exp(0.65)=1.916, sum=4.119
G(x₂) = [0.000, 0.535, 0.465, 0.000]
Token 3 "garam": softmax([0.45, 0.45])
exp(0.45)=1.568, exp(0.45)=1.568, sum=3.136
G(x₃) = [0.000, 0.500, 0.500, 0.000]
Token 4 "hai": softmax([0.58, 0.38])
exp(0.58)=1.786, exp(0.38)=1.462, sum=3.248
G(x₄) = [0.550, 0.450, 0.000, 0.000]
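Running the softmax on the masked vectors gives the same gates as a softmax over just the two surviving scores, because exp(−∞) = 0. Below is a numerically stable sketch that subtracts the maximum before exponentiating. Here `softmax` is a local helper, not a library call:

```python
import math


def softmax(v):
    """Numerically stable softmax; exp(-inf) = 0, so masked experts get weight 0."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    s = sum(exps)
    return [e / s for e in exps]


# Top-2 masked logits from Step 2.
masked = {
    "chai": [0.96, 0.66, -math.inf, -math.inf],
    "bahut": [-math.inf, 0.79, 0.65, -math.inf],
    "garam": [-math.inf, 0.45, 0.45, -math.inf],
    "hai": [0.58, 0.38, -math.inf, -math.inf],
}
G = {w: softmax(m) for w, m in masked.items()}
for w, g in G.items():
    print(f"{w:6s}", [f"{p:.3f}" for p in g])
```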
Step 4: Compute MoE outputs
Token 1 “chai” — Experts 1 and 2 activated:
E₁(x₁) = x₁ · W_E1 = [1.0, 0.2] · [[1.2, 0.0], [0.0, 0.5]]
= [1.0×1.2 + 0.2×0.0, 1.0×0.0 + 0.2×0.5]
= [1.200, 0.100]
E₂(x₁) = x₁ · W_E2 = [1.0, 0.2] · [[0.3, 0.0], [0.0, 1.4]]
= [0.300, 0.280]
MoE(x₁) = 0.574×[1.200, 0.100] + 0.426×[0.300, 0.280]
= [0.689, 0.057] + [0.128, 0.119]
= [0.817, 0.176]
Expert 1 (the noun specialist) gets the larger share of “chai” (57.4%). That fits, because “chai” is a concrete noun.
Token 2 “bahut” — Experts 2 and 3 activated:
E₂(x₂) = [0.3, 0.8] · [[0.3,0],[0,1.4]] = [0.090, 1.120]
E₃(x₂) = [0.3, 0.8] · [[0.2,0.8],[0.9,0.1]] = [0.3×0.2+0.8×0.9, 0.3×0.8+0.8×0.1]
= [0.060+0.720, 0.240+0.080] = [0.780, 0.320]
MoE(x₂) = 0.535×[0.090, 1.120] + 0.465×[0.780, 0.320]
= [0.048, 0.599] + [0.363, 0.149]
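= [0.411, 0.748]
Token 3 “garam” (Experts 2 and 3 activated, each with gate weight 0.500):
E₂(x₃) = [0.1, 0.5] · [[0.3,0],[0,1.4]] = [0.030, 0.700]
E₃(x₃) = [0.1, 0.5] · [[0.2,0.8],[0.9,0.1]] = [0.1×0.2+0.5×0.9, 0.1×0.8+0.5×0.1]
= [0.020+0.450, 0.080+0.050] = [0.470, 0.130]
MoE(x₃) = 0.500×[0.030, 0.700] + 0.500×[0.470, 0.130]
= [0.015, 0.350] + [0.235, 0.065]
= [0.250, 0.415]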
Token 4 “hai” — Experts 1 and 2 activated (same experts as “chai” but different weights):
E₁(x₄) = [0.6, 0.1] · [[1.2,0],[0,0.5]] = [0.720, 0.050]
E₂(x₄) = [0.6, 0.1] · [[0.3,0],[0,1.4]] = [0.180, 0.140]
MoE(x₄) = 0.550×[0.720, 0.050] + 0.450×[0.180, 0.140]
= [0.396, 0.028] + [0.081, 0.063]
= [0.477, 0.091]
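The combination step is MoE(x) = Σᵢ Gᵢ(x)·Eᵢ(x), summed only over experts with a non-zero gate. The plain-Python sketch below uses the rounded gates from Step 3 and the expert matrices from the setup. The names `matvec` and `moe_output` are hypothetical helpers. Because the sketch does not round intermediate products, its last digit can differ slightly from the hand calculation. For example, it gives 0.177 rather than 0.176 for the second component of “chai”.

```python
# Expert weight matrices from the setup; E_i(x) = x · W_Ei.
W_E = [
    [[1.2, 0.0], [0.0, 0.5]],  # E1: nouns / concrete objects
    [[0.3, 0.0], [0.0, 1.4]],  # E2: degree / quantifier words
    [[0.2, 0.8], [0.9, 0.1]],  # E3: descriptive / adjective words
    [[0.7, 0.3], [0.1, 0.6]],  # E4: function words / verbs
]

# (token vector, top-2 gate vector G(x) from Step 3)
batch = {
    "chai": ([1.0, 0.2], [0.574, 0.426, 0.0, 0.0]),
    "bahut": ([0.3, 0.8], [0.0, 0.535, 0.465, 0.0]),
    "garam": ([0.1, 0.5], [0.0, 0.5, 0.5, 0.0]),
    "hai": ([0.6, 0.1], [0.550, 0.450, 0.0, 0.0]),
}


def matvec(x, W):
    """Row vector times matrix."""
    return [sum(x[d] * W[d][j] for d in range(len(x))) for j in range(len(W[0]))]


def moe_output(x, gates):
    """Gate-weighted sum of expert outputs; zero-gate experts are never evaluated."""
    y = [0.0] * len(W_E[0][0])
    for i, g in enumerate(gates):
        if g == 0.0:
            continue  # sparsity: a skipped expert costs no compute
        e = matvec(x, W_E[i])
        y = [yj + g * ej for yj, ej in zip(y, e)]
    return y


outputs = {word: moe_output(x, G) for word, (x, G) in batch.items()}
for word, y in outputs.items():
    print(f"{word:6s}", [f"{v:.3f}" for v in y])
```

Expert 4 is never evaluated for any token in this batch, which is where the compute savings of sparse routing come from.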
Step 5: Compute auxiliary balancing loss
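Count how many times each expert was selected (hard routing) across all T = 4 tokens, then divide by T to get fᵢ. Because each token picks k = 2 experts, the fᵢ sum to k = 2, not 1: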
Expert 1: tokens 1, 4 → f₁ = 2/4 = 0.500
Expert 2: tokens 1, 2, 3, 4 → f₂ = 4/4 = 1.000
Expert 3: tokens 2, 3 → f₃ = 2/4 = 0.500
Expert 4: no tokens → f₄ = 0/4 = 0.000
Expert 2 is getting all 4 tokens! Expert 4 gets none. This is a mild expert collapse scenario.
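Counting the hard assignments takes only a few lines over the Step 2 routing decisions. This is a minimal sketch, assuming the 1-indexed expert numbers from Step 2 are stored in a plain dict. The name `routes` is hypothetical.

```python
from collections import Counter

# Experts picked by top-2 routing in Step 2 (1-indexed, as in the text).
routes = {
    "chai": [1, 2],
    "bahut": [2, 3],
    "garam": [2, 3],
    "hai": [1, 2],
}
n_experts = 4
T = len(routes)  # number of tokens in the batch

counts = Counter(e for chosen in routes.values() for e in chosen)
# Counter returns 0 for experts that were never picked (Expert 4 here).
f = [counts[i] / T for i in range(1, n_experts + 1)]
print(f)       # [0.5, 1.0, 0.5, 0.0]
print(sum(f))  # 2.0, i.e. k, because every token picks two experts
```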
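Soft gating probabilities pᵢ: for each token, take the full softmax over all four raw logits from Step 1 (no top-k mask), then average over the 4 tokens:
"chai": [0.406, 0.301, 0.115, 0.179]
"bahut": [0.188, 0.361, 0.314, 0.137]
"garam": [0.199, 0.313, 0.313, 0.175]
"hai": [0.346, 0.283, 0.159, 0.212]
p₁ ≈ 0.285, p₂ ≈ 0.314, p₃ ≈ 0.225, p₄ ≈ 0.176
Auxiliary loss (with n=4, α=0.01):
L_balance = 0.01 × 4 × (f₁p₁ + f₂p₂ + f₃p₃ + f₄p₄)
= 0.04 × (0.500×0.285 + 1.000×0.314 + 0.500×0.225 + 0.000×0.176)
= 0.04 × (0.1425 + 0.314 + 0.1125 + 0.000)
= 0.04 × 0.569
≈ 0.0228
For comparison, perfectly balanced routing (every fᵢ = 2/4 = 0.5, every pᵢ = 0.25) gives 0.04 × (4 × 0.5 × 0.25) = 0.020. This batch therefore sits about 14% above the balanced value. Only p carries a gradient, because f comes from a hard top-k choice and is not differentiable. Since ∂L/∂pᵢ = α·n·fᵢ, the penalty falls hardest on Expert 2 (f₂ = 1.0) and not at all on Expert 4 (f₄ = 0). Pushed back through the softmax, the gradient on the router logits is positive for Expert 2 and negative for Expert 4 on every token. A gradient-descent step on W_g therefore shifts probability away from Expert 2 and toward Expert 4. Over many batches, this pressure tends to even out the load.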
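The whole auxiliary-loss calculation fits in a short sketch, including where p comes from and the sign of its gradient. It assumes the Switch-Transformer-style form used above, L = α·n·Σᵢ fᵢpᵢ, with f treated as a constant. The gradient formula in the comments comes from differentiating the softmax by hand. All helper names are hypothetical.

```python
import math


def softmax(v):
    """Numerically stable softmax over a list of logits."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    s = sum(exps)
    return [e / s for e in exps]


# Raw (unmasked) gating logits from Step 1, one row per token.
logits = [
    [0.96, 0.66, -0.30, 0.14],  # chai
    [0.14, 0.79, 0.65, -0.18],  # bahut
    [0.00, 0.45, 0.45, -0.13],  # garam
    [0.58, 0.38, -0.20, 0.09],  # hai
]
k, alpha = 2, 0.01
T, n = len(logits), len(logits[0])

probs = [softmax(h) for h in logits]
# p_i: mean full-softmax probability of expert i (the differentiable part).
p = [sum(row[i] for row in probs) / T for i in range(n)]

# f_i: fraction of tokens whose top-k includes expert i (hard, no gradient).
f = [0.0] * n
for h in logits:
    for i in sorted(range(n), key=lambda j: -h[j])[:k]:
        f[i] += 1 / T

loss = alpha * n * sum(fi * pi for fi, pi in zip(f, p))
balanced = alpha * n * n * (k / n) * (1 / n)  # every f_i = k/n, every p_i = 1/n

# dL/dz[t][j] = alpha * n / T * P[t][j] * (f[j] - sum_i f[i] * P[t][i]),
# where z[t] are token t's logits and P[t] = softmax(z[t]).
grad = []
for row in probs:
    mean_f = sum(fi * pi for fi, pi in zip(f, row))
    grad.append([alpha * n / T * pj * (fj - mean_f) for pj, fj in zip(row, f)])

print("p =", [f"{x:.3f}" for x in p])
print(f"loss = {loss:.4f}, balanced = {balanced:.4f}")
print("sign of dL/dz per token:", [["+" if g > 0 else "-" for g in row] for row in grad])
```

A positive entry means gradient descent lowers that logit. Here Expert 2's column is "+" for every token and Expert 4's is "-".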
Summary of token routing
"chai" (noun) → Expert 1 (57%) + Expert 2 (43%)
"bahut" (adverb) → Expert 2 (53%) + Expert 3 (47%)
"garam" (adjective) → Expert 2 (50%) + Expert 3 (50%)
"hai" (verb) → Expert 1 (55%) + Expert 2 (45%)
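Even in this tiny toy example with hand-crafted weights, the routing is partly sensible. The noun specialist E₁ takes the larger share of the noun “chai”, the degree specialist E₂ takes the larger share of “bahut”, and the adjective specialist E₃ gets half of “garam”. It is not perfect, though. The verb “hai” goes to E₁ and E₂ rather than to the function-word specialist E₄, which receives no tokens at all, and that imbalance is exactly what the auxiliary loss is designed to penalise. In a real trained model with hundreds or thousands of experts and millions of training steps, the specialisations become far more nuanced.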
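For reference, the whole routing path (logits, top-2, softmax over the survivors) can be reproduced end to end in a few lines. This sketch makes the same assumptions as the hand calculation: hand-crafted W_g, k = 2, and ties going to the lower expert index. `route` is a hypothetical helper, not any library's API.

```python
import math

# Hand-crafted toy parameters from the setup.
W_g = [[1.0, 0.5, -0.5, 0.2], [-0.2, 0.8, 1.0, -0.3]]
tokens = {"chai": [1.0, 0.2], "bahut": [0.3, 0.8], "garam": [0.1, 0.5], "hai": [0.6, 0.1]}


def route(x, W, k=2):
    """Top-k routing for one token: return [(expert_number, gate_weight), ...]."""
    h = [sum(x[d] * W[d][j] for d in range(len(x))) for j in range(len(W[0]))]
    top = sorted(range(len(h)), key=lambda j: -h[j])[:k]  # stable: ties go to lower index
    m = max(h[j] for j in top)
    exps = {j: math.exp(h[j] - m) for j in top}  # softmax over the k survivors only
    s = sum(exps.values())
    return [(j + 1, exps[j] / s) for j in top]


for word, x in tokens.items():
    print(f"{word:6s}", " + ".join(f"Expert {e} ({g:.0%})" for e, g in route(x, W_g)))
```

Gates are printed as whole percentages, so they match the summary table up to rounding.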