6. The code — sparse MoE forward pass in NumPy
🟡 First-year college. Basic Python and NumPy. Runs free on Google Colab — no GPU needed.
We implement the complete MoE layer: gating network, top-k selection, sparse expert evaluation, auxiliary balancing loss.
import numpy as np
def top_k_gating(x, W_g, k):
    """
    Compute sparse gating weights for one token x.

    Only the k highest-scoring experts get a non-zero weight; those weights
    are a softmax over the k selected logits, so they sum to 1.

    x: (d_model,) — token representation
    W_g: (d_model, n) — gating weight matrix
    k: int — number of experts to activate, 1 <= k <= n
    Returns: gate_weights (n,), selected_indices (k,)
        selected_indices is ordered by ascending logit (best expert last).
    Raises: ValueError if k is outside [1, n].
    """
    n = W_g.shape[1]
    # Validate k up front: argsort(...)[-0:] is the *whole* array, so k = 0
    # would silently activate every expert, and k > n would return fewer
    # indices than the caller asked for.
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and {n} (number of experts), got {k}")

    logits = x @ W_g  # raw score per expert: (n,)

    # argsort is ascending, so the last k positions hold the k largest scores.
    top_k_idx = np.argsort(logits)[-k:]
    top_logits = logits[top_k_idx]

    # Softmax restricted to the selected experts. Subtracting the max keeps
    # exp() from overflowing; unselected experts simply stay at weight 0.
    exp_top = np.exp(top_logits - top_logits.max())
    gate_weights = np.zeros(n)
    gate_weights[top_k_idx] = exp_top / exp_top.sum()
    return gate_weights, top_k_idx
def expert_ffn(x, W1, W2):
    """Run one expert: project x through W1, apply ReLU, project through W2."""
    pre_activation = x @ W1
    # ReLU: negative pre-activations are clamped to zero.
    activated = np.maximum(0, pre_activation)
    projected = activated @ W2
    return projected
def moe_forward(x, W_g, experts, k):
    """
    Full MoE forward pass for a single token.

    The token is scored against every expert by the gating matrix, but only
    the k selected experts are actually evaluated; their outputs are blended
    using the sparse gate weights.

    x: (d_model,) — token representation
    W_g: (d_model, n_experts) — gating matrix
    experts: list of (W1, W2) tuples — one per expert
    k: int — top-k
    Returns: output (d_out,), gate_weights (n_experts,)
    Raises: ValueError if W_g does not have exactly one column per expert.
    """
    n = len(experts)
    # The gate must score exactly the experts we hold. Otherwise routing can
    # index past the end of `experts`, or some experts can never be chosen.
    if W_g.shape[1] != n:
        raise ValueError(
            f"W_g has {W_g.shape[1]} expert columns but {n} experts were given"
        )

    gate_weights, active_idx = top_k_gating(x, W_g, k)

    d_out = experts[0][1].shape[1]  # output dimension from W2
    output = np.zeros(d_out)
    for i in active_idx:  # only evaluate the selected experts
        W1, W2 = experts[i]
        output += gate_weights[i] * expert_ffn(x, W1, W2)
    return output, gate_weights
def auxiliary_balance_loss(all_gate_weights, all_hard_assignments, n_experts, alpha=0.01):
    """
    Compute the load-balancing auxiliary loss over a batch.

    loss = alpha * n_experts * sum_i(f_i * p_i)

    all_gate_weights: (T, n) soft gate probabilities (full softmax, before top-k)
    all_hard_assignments: T sequences of selected expert indices, one per token
    n_experts: int — number of experts n
    alpha: float — scale of the auxiliary loss
    Returns: (loss, f, p)
        f: (n,) selections per expert divided by T; sums to k when every
           token selects k experts
        p: (n,) mean soft probability per expert across the batch
    An empty batch (T == 0) returns loss 0.0 with all-zero f and p.
    """
    gate_probs = np.asarray(all_gate_weights)
    T = gate_probs.shape[0]

    f = np.zeros(n_experts)
    if T == 0:
        # Nothing was routed: avoid 0/0 and mean-of-empty, which would turn
        # the loss into NaN.
        return 0.0, f, np.zeros(n_experts)

    # f_i: fraction of tokens routed to each expert (hard assignment).
    # np.add.at accumulates repeated indices, unlike f[idx] += 1 with an
    # index array, which would count each expert at most once.
    chosen = [idx for assignments in all_hard_assignments for idx in assignments]
    np.add.at(f, np.asarray(chosen, dtype=np.intp), 1.0)
    f /= T

    # p_i: mean soft probability for each expert across the batch.
    p = gate_probs.mean(axis=0)

    loss = alpha * n_experts * np.sum(f * p)
    return loss, f, p
# ── Demo ──────────────────────────────────────────────────────────────────────
# Fixed seed so the random weights and tokens below are reproducible.
np.random.seed(7)

d_model = 8
n_experts = 6
d_ff = 16
d_out = 8
k = 2


def _random_expert():
    """Draw one expert's (W1, W2) pair; W1 is drawn before W2."""
    w_in = 0.3 * np.random.randn(d_model, d_ff)
    w_out = 0.3 * np.random.randn(d_ff, d_out)
    return w_in, w_out


# Random gating matrix and expert weights (in practice, learned by training).
# Draw order matters for reproducibility: gate first, then experts in order.
W_g = 0.3 * np.random.randn(d_model, n_experts)
experts = [_random_expert() for _ in range(n_experts)]

# Simulate a small batch: "chai bahut garam hai" (4 tokens)
token_names = ["chai", "bahut", "garam", "hai"]
batch_tokens = 0.5 * np.random.randn(4, d_model)
# Dense softmax over every expert (no top-k masking); feeds the balancing loss.
def full_softmax(x, W_g):
    """Return the softmax of the gating logits x @ W_g across all experts."""
    scores = x @ W_g
    # Shift by the max so the largest exponent is exp(0) and cannot overflow.
    shifted = scores - scores.max()
    unnormalised = np.exp(shifted)
    return unnormalised / unnormalised.sum()
# Per-token records gathered while routing the batch.
all_gate_weights = []       # dense softmax probabilities, one row per token
all_hard_assignments = []   # indices of the experts that actually fired
outputs = []                # blended MoE output per token

print("Routing decisions:")
for name, x in zip(token_names, batch_tokens):
    # Sparse forward pass; hard_weights is zero for every unselected expert.
    out, hard_weights = moe_forward(x, W_g, experts, k)
    active = np.flatnonzero(hard_weights > 0)

    # Dense probabilities are only needed for the balancing loss.
    soft_weights = full_softmax(x, W_g)

    all_gate_weights.append(soft_weights)
    all_hard_assignments.append(active)
    outputs.append(out)
    print(f" {name:7s} → Experts {active} | weights {np.round(hard_weights[active], 3)}")

# Stack the per-token probabilities into a (tokens, experts) matrix and score
# how evenly the batch was spread across experts.
all_gate_weights_arr = np.asarray(all_gate_weights)
loss, f, p = auxiliary_balance_loss(
    all_gate_weights_arr,
    all_hard_assignments,
    n_experts,
)

print(f"\nLoad distribution (f): {np.round(f, 3)}")
print(f"Soft probabilities (p): {np.round(p, 3)}")
print(f"Auxiliary balance loss: {loss:.4f}")
print(f"\nExperts unused this batch: {list(np.where(f == 0)[0])}")
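Expected output (illustrative — the seed is fixed at 7, so your own run is reproducible, but its exact experts and numbers need not match the values shown here):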
Routing decisions:
chai → Experts [2 4] | weights [0.531 0.469]
bahut → Experts [1 3] | weights [0.612 0.388]
garam → Experts [0 4] | weights [0.558 0.442]
hai → Experts [2 5] | weights [0.503 0.497]
Load distribution (f): [0.25 0.25 0.5 0.25 0.5 0.25]
Soft probabilities (p): [0.16 0.18 0.21 0.15 0.19 0.11]
Auxiliary balance loss: 0.0210
Experts unused this batch: []
Try this: Change k = 1 (top-1 routing). Notice the routing becomes “harder” — each token goes to exactly one expert, and the output is simply that expert’s value (no blending). This is what the Switch Transformer (2021) uses for simplicity and speed. Compare the auxiliary loss — does it go up or down?
Also try n_experts = 4, k = 2 and look at what fraction of experts get zero tokens in small batches. This illustrates why capacity factor and load balancing are critical in production systems.