Section 06

The code: sparse MoE forward pass in NumPy

Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer 2017

🟡 First-year college. Basic Python and NumPy. Runs free on Google Colab — no GPU needed.

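We implement a simplified MoE layer for one token at a time: a gating network, top-k selection, sparse expert evaluation, and an auxiliary balancing loss. Two simplifications relative to the paper: the gate omits the paper's noise term, and the balancing loss uses the f·p form popularized by the Switch Transformer rather than the paper's CV-squared importance and load losses.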
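For reference, the computation that the listing in this section carries out can be written compactly as follows. The symbols f_i and p_i are the names used in the listing's balancing-loss function; T (the number of tokens in the batch) and the superscript (i) on the expert weights are notation introduced here for illustration only.

```latex
% x: one token (row vector), W_g: gating matrix, n: number of experts,
% k: number of active experts, T: number of tokens in the batch.
\[
\begin{aligned}
\operatorname{KeepTopK}(v, k)_i &=
  \begin{cases}
    v_i     & \text{if } v_i \text{ is among the } k \text{ largest entries of } v \\
    -\infty & \text{otherwise}
  \end{cases} \\
G(x) &= \operatorname{softmax}\bigl(\operatorname{KeepTopK}(x W_g,\, k)\bigr) \\
E_i(x) &= \operatorname{ReLU}\bigl(x W_1^{(i)}\bigr)\, W_2^{(i)} \\
y &= \sum_{i=1}^{n} G(x)_i \, E_i(x) \\
f_i &= \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}\bigl[\, i \in \operatorname{TopK}(x_t W_g,\, k) \,\bigr],
\qquad
p_i = \frac{1}{T} \sum_{t=1}^{T} \operatorname{softmax}(x_t W_g)_i \\
\mathcal{L}_{\text{aux}} &= \alpha \, n \sum_{i=1}^{n} f_i \, p_i
\end{aligned}
\]
```

Because G(x)_i is exactly zero for every expert outside the top k, the sum for y only needs the k selected experts, which is where the compute saving comes from.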

import numpy as np

def top_k_gating(x, W_g, k):
    """
    Compute sparse gating weights for one token x.
    x:   (d_model,)        — token representation
    W_g: (d_model, n)      — gating weight matrix
    k:   int               — number of experts to activate
    Returns: gate_weights (n,), selected_indices (k,)
    """
    n = W_g.shape[1]
    
    logits = x @ W_g                           # raw score per expert: (n,)
    
    # Find top-k indices
    top_k_idx = np.argsort(logits)[-k:]        # indices of k largest scores
    
    # Mask all but top-k to -infinity
    masked = np.full(n, -np.inf)               # start with all -inf
    masked[top_k_idx] = logits[top_k_idx]      # fill in top-k scores
    
    # Softmax over unmasked values (e^(-inf) = 0, so masked experts vanish)
    exp_masked = np.exp(masked - masked[top_k_idx].max())  # numerical stability
    gate_weights = np.zeros(n)
    gate_weights[top_k_idx] = exp_masked[top_k_idx] / exp_masked[top_k_idx].sum()
    
    return gate_weights, top_k_idx

def expert_ffn(x, W1, W2):
    """Two-layer FFN: ReLU(x·W1)·W2 — one expert's forward pass."""
    hidden = np.maximum(0, x @ W1)            # ReLU activation
    return hidden @ W2

def moe_forward(x, W_g, experts, k):
    """
    Full MoE forward pass for a single token.
    x:       (d_model,)                  — token representation
    W_g:     (d_model, n_experts)        — gating matrix
    experts: list of (W1, W2) tuples     — one per expert
    k:       int                         — top-k
    """
    gate_weights, active_idx = top_k_gating(x, W_g, k)
    
    d_out = experts[0][1].shape[1]             # output dimension from W2
    output = np.zeros(d_out)
    
    for i in active_idx:                       # only evaluate k experts
        W1, W2 = experts[i]
        output += gate_weights[i] * expert_ffn(x, W1, W2)
    
    return output, gate_weights

def auxiliary_balance_loss(all_gate_weights, all_hard_assignments, n_experts, alpha=0.01):
    """
    Compute a load-balancing auxiliary loss over a batch of T tokens.
    all_gate_weights:     (T, n) soft gate probabilities (full softmax, before top-k)
    all_hard_assignments: (T, k) indices of selected experts per token
    Returns: loss (float), f (n,), p (n,)
    Note: each token is counted once per selected expert, so f sums to k, not 1.
    """
    T = all_gate_weights.shape[0]
    
    # f_i: fraction of tokens routed to each expert (hard assignment)
    f = np.zeros(n_experts)
    for assignments in all_hard_assignments:
        for idx in assignments:
            f[idx] += 1
    f /= T                                     # normalise to fraction
    
    # p_i: mean soft probability for each expert across batch
    p = all_gate_weights.mean(axis=0)
    
    loss = alpha * n_experts * np.sum(f * p)
    return loss, f, p

# ── Demo ──────────────────────────────────────────────────────────────────────
np.random.seed(7)
d_model, n_experts, d_ff, d_out, k = 8, 6, 16, 8, 2

# Random gating matrix and expert weights (in practice, learned by training)
W_g = np.random.randn(d_model, n_experts) * 0.3
experts = [(np.random.randn(d_model, d_ff) * 0.3,
            np.random.randn(d_ff, d_out)   * 0.3)
           for _ in range(n_experts)]

# Simulate a small batch: "chai bahut garam hai" (Hindi: "the tea is very hot"), 4 tokens
batch_tokens = np.random.randn(4, d_model) * 0.5
token_names  = ["chai", "bahut", "garam", "hai"]

# Full softmax (before masking) for the balancing loss
def full_softmax(x, W_g):
    logits = x @ W_g
    logits -= logits.max()                     # stability
    e = np.exp(logits)
    return e / e.sum()

all_gate_weights    = []
all_hard_assignments = []
outputs             = []

print("Routing decisions:")
for name, x in zip(token_names, batch_tokens):
    soft_weights = full_softmax(x, W_g)        # soft (for loss)
    out, hard_weights = moe_forward(x, W_g, experts, k)
    active = np.where(hard_weights > 0)[0]     # which experts fired
    
    all_gate_weights.append(soft_weights)
    all_hard_assignments.append(active)
    outputs.append(out)
    
    print(f"  {name:7s} → Experts {active} | weights {np.round(hard_weights[active], 3)}")

# Compute balancing loss
all_gate_weights_arr = np.array(all_gate_weights)   # (4, 6)
loss, f, p = auxiliary_balance_loss(all_gate_weights_arr, all_hard_assignments, n_experts)

print(f"\nLoad distribution (f): {np.round(f, 3)}")
print(f"Soft probabilities (p): {np.round(p, 3)}")
print(f"Auxiliary balance loss: {loss:.4f}")
print(f"\nExperts unused this batch: {list(np.where(f == 0)[0])}")

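Sample output (illustrative: the seed is fixed, so your own run is reproducible, but the experts and weights it prints may differ from the ones shown here):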

Routing decisions:
  chai    → Experts [2 4] | weights [0.531 0.469]
  bahut   → Experts [1 3] | weights [0.612 0.388]
  garam   → Experts [0 4] | weights [0.558 0.442]
  hai     → Experts [2 5] | weights [0.503 0.497]

Load distribution (f): [0.25 0.25 0.5  0.25 0.5  0.25]
Soft probabilities (p): [0.16 0.18 0.21 0.15 0.19 0.11]
Auxiliary balance loss: 0.0210

Experts unused this batch: []
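The balancing loss in the listing above multiplies hard routing fractions by soft probabilities. The 2017 paper formulates balance differently: it sums each expert's gate values over the batch (the expert's "importance") and penalizes the squared coefficient of variation of those sums. The sketch below follows that description; the function name, the `w_importance` default, the small epsilon, and the two toy batches are assumptions made for illustration, not values taken from the paper.

```python
import numpy as np

def importance_loss(gate_weights, w_importance=0.1):
    """
    Squared-coefficient-of-variation importance loss (a sketch).
    gate_weights: (T, n) sparse gate values per token, zeros for unselected experts
    w_importance: hand-tuned scaling factor (default here is an assumption)
    """
    importance = gate_weights.sum(axis=0)          # (n,) total gate mass per expert
    # CV^2 = variance / mean^2; population variance, epsilon avoids division by zero
    cv_squared = importance.var() / (importance.mean() ** 2 + 1e-10)
    return w_importance * cv_squared

# Two toy batches of 2 tokens over 4 experts, top-2 routing (made-up numbers)
balanced = np.array([[0.5, 0.5, 0.0, 0.0],
                     [0.0, 0.0, 0.5, 0.5]])
collapsed = np.array([[0.6, 0.4, 0.0, 0.0],
                      [0.7, 0.3, 0.0, 0.0]])

print(f"balanced routing:  {importance_loss(balanced):.3f}")   # 0.000
print(f"collapsed routing: {importance_loss(collapsed):.3f}")  # 0.118
```

The loss is zero when every expert receives the same total gate mass and grows as routing collapses onto a few experts. The paper pairs it with a second, load-based term that this sketch leaves out.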

Try this: set k = 1 (top-1 routing). The routing becomes “harder”: each token goes to exactly one expert, its gate weight is exactly 1 (a softmax over a single score), and the output is simply that expert’s output with no blending. The Switch Transformer (2021) uses top-1 routing for simplicity and speed. Compare the auxiliary loss: does it go up or down?
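To see the effect of k on the gate weights without rerunning the whole demo, here is a minimal standalone sketch. The helper `top_k_softmax` and the four logits are made up for illustration; the helper applies the same mask-then-softmax step as `top_k_gating`, but starts from precomputed scores.

```python
import numpy as np

def top_k_softmax(logits, k):
    """Softmax restricted to the k largest logits; every other entry is 0."""
    top_idx = np.argsort(logits)[-k:]                   # indices of the k largest scores
    shifted = logits[top_idx] - logits[top_idx].max()   # subtract max for stability
    weights = np.zeros_like(logits, dtype=float)
    weights[top_idx] = np.exp(shifted) / np.exp(shifted).sum()
    return weights

logits = np.array([0.2, 1.5, -0.3, 0.9])   # hypothetical gating scores for 4 experts

w_top1 = top_k_softmax(logits, k=1)
w_top2 = top_k_softmax(logits, k=2)

print("top-1 weights:", np.round(w_top1, 3))   # all mass on expert 1
print("top-2 weights:", np.round(w_top2, 3))   # mass split between experts 1 and 3
```

With k = 1 the weight is 1 regardless of how close the runner-up score was, so the gate weight no longer says anything about the router's confidence. As far as I know, the Switch Transformer avoids this by scaling the expert output with the full softmax probability instead of renormalizing over the selected expert.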

Also try n_experts = 4, k = 2 and look at what fraction of experts gets zero tokens in small batches. This illustrates why load balancing and a capacity factor (the multiplier that sets how many tokens each expert may accept per batch) matter in production systems.
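The idle-expert experiment can also be run in isolation. The sketch below assumes a perfectly uniform router, in which every token picks k distinct experts at random; a trained or freshly initialized gate will not behave exactly like this, so treat the numbers as a best-case baseline. The function name and its parameters are hypothetical.

```python
import numpy as np

def unused_fraction(n_experts, k, batch_size, n_trials=1000, seed=0):
    """
    Average fraction of experts that receive zero tokens in a batch,
    assuming each token picks k distinct experts uniformly at random.
    """
    rng = np.random.default_rng(seed)
    total = 0.0
    for _trial in range(n_trials):
        used = np.zeros(n_experts, dtype=bool)
        for _token in range(batch_size):
            chosen = rng.choice(n_experts, size=k, replace=False)
            used[chosen] = True
        total += 1.0 - used.mean()                 # share of experts left idle
    return total / n_trials

for batch_size in (1, 2, 4, 8):
    frac = unused_fraction(n_experts=4, k=2, batch_size=batch_size)
    print(f"batch of {batch_size}: about {frac:.2f} of experts idle")
```

Under this uniform assumption the expected idle fraction is (1 - k/n)^T for a batch of T tokens, so it shrinks quickly as the batch grows. A router that favors a few experts leaves far more of them idle, which is the situation the balancing loss is meant to prevent.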