Section 03

The Idea: Making SSMs Input-Dependent

Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023)


The Core Innovation

Instead of this (fixed SSM):

x_t = A x_{t-1} + B u_t          (A, B, C fixed for all t)
y_t = C x_t

Mamba does this (selective SSM):

Δ_t = softplus(W_Δ u_t)          (step size depends on input)
B_t = Linear_B(u_t)               (input matrix depends on input)
C_t = Linear_C(u_t)               (output matrix depends on input)

A_bar_t = exp(Δ_t A)              (discretized, input-dependent)
B_bar_t = Δ_t B_t                 (scaled input matrix)

x_t = A_bar_t x_{t-1} + B_bar_t u_t
y_t = C_t x_t
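A minimal single-channel sketch of the selective recurrence above, in plain Python. It assumes a scalar input sequence, a diagonal A given by its (negative) diagonal entries, and the simplified B_bar_t = Δ_t B_t. The weights w_delta, b_delta, w_b and w_c are hypothetical stand-ins for the learned projections, not Mamba's actual parameterization.

```python
import math


def softplus(x: float) -> float:
    # log(1 + e^x), rearranged so large x does not overflow
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def selective_ssm(u, a, w_delta, b_delta, w_b, w_c):
    """Run a toy single-channel selective SSM over the scalar sequence u.

    a                 -- diagonal of A (length N, negative entries)
    w_delta, b_delta  -- hypothetical weight and bias producing Δ_t
    w_b, w_c          -- hypothetical length-N weights producing B_t and C_t
    """
    x = [0.0] * len(a)                               # state starts at zero
    ys = []
    for u_t in u:
        delta = softplus(w_delta * u_t + b_delta)    # Δ_t depends on the token
        b_t = [w * u_t for w in w_b]                 # B_t depends on the token
        c_t = [w * u_t for w in w_c]                 # C_t depends on the token
        # x_t = A_bar_t x_{t-1} + B_bar_t u_t, with A_bar_t = exp(Δ_t A) (diagonal)
        x = [math.exp(delta * a_n) * x_n + delta * b_n * u_t
             for a_n, x_n, b_n in zip(a, x, b_t)]
        ys.append(sum(c_n * x_n for c_n, x_n in zip(c_t, x)))   # y_t = C_t · x_t
    return ys


y = selective_ssm([1.0, 0.5, -0.2, 0.0], a=[-1.0, -2.0],
                  w_delta=1.0, b_delta=0.0, w_b=[1.0, 0.5], w_c=[1.0, -1.0])
print(y)
```

Because B_t and C_t here are computed from the same token they act on, a token whose projections come out as zero (the final u_t = 0.0) neither writes to nor reads from the state, although the state still decays by exp(Δ_t A).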

Now the model decides, for each token:

  • Δ_t: How long to hold on to the past (small Δ: A_bar ≈ I, so the state persists; large Δ: faster decay, so the state is reset toward the current token)
  • B_t: How much this token should influence the state
  • C_t: How to read out the state
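To see what Δ_t does on its own, the sketch below discretizes a single diagonal entry of A (assumed to be a = -1, a decaying mode) for a range of step sizes, using the simplified B_bar = Δ B. The function name and the values are illustrative.

```python
import math


def discretize_entry(delta: float, a: float, b: float):
    """Return (A_bar, B_bar) for one diagonal entry: exp(Δ a) and Δ b."""
    return math.exp(delta * a), delta * b


for delta in (0.01, 0.1, 1.0, 10.0):
    a_bar, b_bar = discretize_entry(delta, a=-1.0, b=1.0)
    # a_bar: fraction of the old state kept; b_bar: weight on the new token
    print(f"delta={delta:>5}: A_bar={a_bar:.4f}  B_bar={b_bar:.2f}")
```

Small Δ keeps A_bar near 1 and B_bar near 0 (the state persists and the token is mostly ignored); large Δ drives A_bar toward 0 and raises B_bar (the old state is flushed in favor of the current token).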

Intuition: Selective Memory

Think of a student reading a textbook:

Fixed SSM student:

  • Reads each word at a constant pace
  • Forgets old information at a constant rate
  • Can’t speed up for important parts or skim through filler

Mamba student:

  • Sees the word “Definition:” → focuses on it (Δ increases: older context fades quickly and the new content is written in strongly)
  • Sees “for example” → skims (Δ decreases: the state carries forward almost unchanged and the token barely registers)
  • Important concepts → written strongly into memory (high B)
  • Filler words → weak signal (low B)
  • Adjusts what it reads out of memory based on the current content (C)

The student can keep up with the attentive reader (Transformer), who re-reads everything, because it is selective about what it keeps, not because it reads everything carefully.


Mathematical Details

Step 1: Compute Selective Parameters

For each input u_t, compute:

Δ_t = softplus(W_Δ u_t)           ∈ ℝ     (scalar in this single-channel view; always > 0, no upper bound)
B_t = Linear_B(u_t)                ∈ ℝ^N   (N-dimensional, for state size N)
C_t = Linear_C(u_t)                ∈ ℝ^N   (N-dimensional, for output)

Where:
  W_Δ ∈ ℝ^(1 × d)    (projects input to step size)
  Linear_B, Linear_C are learned layers
  softplus(x) = log(1 + exp(x))    (smooth approximation of max(0, x))

Why softplus? It is smooth (differentiable everywhere) and maps every real number into (0, ∞), so Δ_t is always a valid, strictly positive step size.
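A numerically stable softplus in plain Python, shown as a sketch. The naive form overflows for large inputs (math.exp(1000.0) raises OverflowError), so this version uses the identity log(1 + e^x) = max(x, 0) + log(1 + e^(-|x|)).

```python
import math


def softplus(x: float) -> float:
    """log(1 + e^x) without overflow; positive for any input that does not underflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


for x in (-20.0, -1.0, 0.0, 1.0, 20.0, 1000.0):
    print(f"softplus({x}) = {softplus(x):.6g}")
```

Far below zero the output approaches 0 from above (a tiny but positive step; for very negative inputs it eventually underflows to 0.0 in floating point), and far above zero it approaches x itself, which is why softplus behaves like a smooth max(0, x).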

Step 2: Discretization

The continuous SSM has a fixed, learned A matrix (e.g., A = -I, which causes exponential decay; Mamba uses a diagonal A with negative entries). Discretize it using the step size Δ_t:

Discretized: x_t = A_bar_t x_{t-1} + B_bar_t u_t

Where:
  A_bar_t = exp(Δ_t A)                     ∈ ℝ^(N×N)
  B_bar_t = (Δ_t A)^(-1) (exp(Δ_t A) - I) Δ_t B_t  ≈ Δ_t B_t  (exact zero-order hold; Mamba uses the simpler Δ_t B_t)

For small Δ_t:
  A_bar_t ≈ I + Δ_t A                      (matrix exponential Taylor expansion)
  B_bar_t ≈ Δ_t B_t

Key insight: A stays fixed (learned), but A_bar_t changes per token because Δ_t changes.
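For a diagonal A, both discretization rules reduce to per-entry scalar formulas, which makes it easy to check how good the B_bar_t ≈ Δ_t B_t approximation is. A sketch, assuming a single diagonal entry a = -1 and b = 1; math.expm1 keeps exp(Δa) - 1 accurate when Δ is tiny.

```python
import math


def b_bar_zoh(delta: float, a: float, b: float) -> float:
    """Exact zero-order-hold B_bar for one diagonal entry a != 0: (Δa)^-1 (e^(Δa) - 1) Δb."""
    return math.expm1(delta * a) / (delta * a) * delta * b


def b_bar_simple(delta: float, b: float) -> float:
    """The simplified rule B_bar ≈ Δ b."""
    return delta * b


for delta in (0.001, 0.1, 1.0):
    exact = b_bar_zoh(delta, a=-1.0, b=1.0)
    approx = b_bar_simple(delta, b=1.0)
    print(f"delta={delta}: zoh={exact:.6f}  simple={approx:.6f}  "
          f"rel. gap={abs(exact - approx) / approx:.2%}")
```

Since (e^x - 1)/x = 1 + x/2 + ..., the relative gap grows roughly like |Δa|/2, so the approximation is tight in the same small-step regime where A_bar_t ≈ I + Δ_t A holds.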

Step 3: State Update (Recurrent)

x_t = A_bar_t x_{t-1} + B_bar_t u_t    (standard recurrence)

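With a dense A this costs O(N²) per step (a matrix-vector product). Because Mamba's A is diagonal, exp(Δ_t A) is just N scalar exponentials and the update is element-wise, O(N) per channel. See Section 4 for further optimization.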
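Why a diagonal A makes each step cheap: the exponential of a diagonal matrix is just the element-wise exponential of its diagonal. The sketch below checks this against a dense matrix exponential computed by a truncated Taylor series, a deliberately naive method chosen for illustration rather than how real implementations compute it; the diagonal values and step size are arbitrary.

```python
import math


def mat_mul(p, q):
    n = len(p)
    return [[sum(p[i][k] * q[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def expm_taylor(m, terms=30):
    """Dense exp(M) = sum_k M^k / k!, truncated after `terms` terms."""
    n = len(m)
    result = [[float(i == j) for j in range(n)] for i in range(n)]   # identity
    term = [row[:] for row in result]
    for k in range(1, terms):
        term = [[v / k for v in row] for row in mat_mul(term, m)]    # M^k / k!
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


a_diag = [-0.5, -1.0, -2.0]          # diagonal of A
delta = 0.3
n = len(a_diag)
delta_a = [[delta * a_diag[i] if i == j else 0.0 for j in range(n)] for i in range(n)]

a_bar_dense = expm_taylor(delta_a)                    # full N x N matrix
a_bar_diag = [math.exp(delta * a) for a in a_diag]    # N scalar exponentials

# With a diagonal A_bar, the state update is element-wise: O(N) instead of O(N^2).
x_prev = [1.0, 1.0, 1.0]
x_next = [ab * x for ab, x in zip(a_bar_diag, x_prev)]
print(a_bar_diag, x_next)
```

The dense result has the same diagonal and exact zeros elsewhere, so storing and updating only the diagonal loses nothing.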

Step 4: Output

y_t = C_t x_t                          (inner product of two N-vectors, O(N))

Why This Works: The Selectivity Effect

Example: Long-Range Dependency

Suppose we want the model to remember token 0 up to token T.

Fixed SSM:

x_t = A^t x_0 + (stuff from recent tokens)

If A has eigenvalues with magnitude < 1 (stable):
  x_T ≈ A^T x_0    (decayed exponentially)
  If A = 0.9, then 0.9^50 ≈ 0.005, so token 0 has almost vanished after 50 steps

Mamba (selective):

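At token 0 (important): Δ_0 is large → B_bar_0 = Δ_0 B_0 is large → token 0 is written strongly into the state
At tokens 1-T (filler): Δ_t is small → A_bar_t = exp(Δ_t A) ≈ I and B_bar_t ≈ 0 → the state passes through almost unchanged

x_T ≈ A_bar_T × ... × A_bar_1 × B_bar_0 u_0          (ignoring the filler tokens' small inputs)
    = exp((Δ_1 + ... + Δ_T) A) × B_bar_0 u_0       (the factors commute: each is an exponential of a multiple of A)

The decay now depends on the total step size Δ_1 + ... + Δ_T, not on the number of tokens T.

But here's the trick: Δ_t is computed by learned weights (W_Δ), so the model learns when to take large and small steps.
If token 0 matters, the model learns to give it a large Δ_0.
If tokens 1-T don't matter, the model learns to give them small Δ_t.
Result: Mamba can remember token 0 despite the distance.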
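The long-range argument can be checked numerically. The sketch below tracks how much of token 0's contribution survives T later tokens for one diagonal entry of A (assumed a = -1), once with the same Δ for every token and once with a tiny Δ on the filler tokens. The step-size values are made up for illustration; in a trained model they would come from W_Δ.

```python
import math


def retention(deltas, a):
    """Factor multiplying token 0's contribution after the later tokens' updates.

    Each later token t multiplies the state by A_bar_t = exp(Δ_t a).
    """
    keep = 1.0
    for d in deltas:
        keep *= math.exp(d * a)
    return keep


a = -1.0
T = 200
fixed = retention([0.1] * T, a)        # every token uses the same step size
selective = retention([0.001] * T, a)  # filler tokens use a tiny step size
print(f"fixed: {fixed:.3g}, selective: {selective:.3g}")
```

Both schedules span the same 200 tokens, but the fixed one leaves about exp(-20) of the memory while the selective one leaves about exp(-0.2) ≈ 0.82: the product depends only on the sum of the step sizes.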

Example: Input-Dependent Importance

Sequence: “Alice lives in Paris. It was a beautiful day. Where does Alice live?”

Token "Paris" (important):
  u_t = embedding of "Paris"
  B_t = Linear_B(u_t) = [high, high, high, ...]   (large values), and Δ_t is large
  → x_t = A_bar_t x_{t-1} + Δ_t B_t u_t, with a large input term
  → "Paris" strongly encoded in state

Token "beautiful" (filler):
  u_t = embedding of "beautiful"
  B_t = Linear_B(u_t) = [low, low, low, ...]   (small values), and Δ_t is small
  → x_t ≈ x_{t-1} + (small input term)
  → "beautiful" weakly encoded; the rest of the state, including "Paris", passes through almost unchanged

The model learns this selectivity from data; none of it is hand-coded.

Hardware Efficiency: The Parallel Scan Trick

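Naively, the recurrence is sequential: x_t cannot be computed until x_{t-1} is known. With a dense A, each step is also an O(N²) matrix-vector product, so a sequence of length n costs O(n N²); Mamba's diagonal A brings this down to O(n N) per channel, but the sequential dependency remains.

But there's a trick: the parallel scan algorithm (also called parallel prefix scan).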

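Training: Use a Parallel Scan, Not a Convolution

Earlier time-invariant SSMs such as S4 can unroll the recurrence into one long convolution, because A_bar and B_bar are the same at every step:

y = u * K,   K = (C B_bar, C A_bar B_bar, C A_bar² B_bar, ...)    (computed with FFT in O(n log n))

Mamba gives up this trick: A_bar_t, B_bar_t and C_t change at every token, so no single kernel K exists. Instead it uses the fact that each update x_t = A_bar_t x_{t-1} + B_bar_t u_t is an affine map, and composing affine maps is associative. A parallel scan can therefore combine all the updates in O(log n) parallel steps with O(n) total work, which parallelizes training across time steps despite the recurrent structure. The Mamba paper implements this scan as a fused, hardware-aware GPU kernel.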
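A sketch of why the scan parallelizes. Each step is the affine map x ↦ a_t x + b_t (per state entry, since A is diagonal; a_t stands for an entry of A_bar_t and b_t for the matching entry of B_bar_t u_t), and composing two such maps gives another one. This version uses the simple Hillis-Steele scan, which needs log₂(n) rounds but does O(n log n) work; a work-efficient (Blelloch-style) scan gets the total work down to O(n). Plain Python only shows the data flow, not GPU performance.

```python
import math
import random


def combine(earlier, later):
    """Compose affine maps: apply `earlier` (x -> a1*x + b1), then `later`."""
    a1, b1 = earlier
    a2, b2 = later
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(a, b):
    """Reference: x_t = a_t * x_{t-1} + b_t, starting from x_0 = 0."""
    x, out = 0.0, []
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        out.append(x)
    return out


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan over the affine maps."""
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        # Every position reads the previous round's values, so on a GPU all
        # positions in a round could be computed at the same time.
        elems = [combine(elems[i - step], elems[i]) if i >= step else elems[i]
                 for i in range(len(elems))]
        step *= 2
    # Starting from x_0 = 0, applying each composed map gives just its offset.
    return [offset for _, offset in elems]


random.seed(0)
n = 37
a = [math.exp(-random.random()) for _ in range(n)]   # entries of A_bar_t, in (0, 1]
b = [random.uniform(-1.0, 1.0) for _ in range(n)]    # entries of B_bar_t * u_t
print(max(abs(p - s) for p, s in zip(parallel_scan(a, b), sequential_scan(a, b))))
```

The printed difference is at the level of floating-point round-off: because combine is associative, the scan only regroups the same products and sums.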

Inference: Use Recurrence with Fused Kernels

During inference, we generate one token at a time, so the recurrence structure is unavoidable. But:

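Standard way (A is diagonal, so every operation is element-wise):
  x_t = exp(Δ_t A) ⊙ x_{t-1} + Δ_t B_t u_t
  (several separate operations, each writing its intermediate result to GPU memory and reading it back)

Fused kernel (hardware-aware; fused_kernel is schematic, not a real function):
  x_t = fused_kernel(A, Δ_t, B_t, u_t, x_{t-1})
  (one optimized kernel call that keeps intermediates in fast on-chip memory, minimizing memory movement)

For element-wise work like this, memory bandwidth, not compute, is the bottleneck on modern GPUs, so fusing the operations reduces memory traffic dramatically. The same idea drives Mamba's training scan, which keeps the expanded state in on-chip SRAM instead of writing it to GPU main memory.

Result: the Mamba paper reports about 5× higher inference throughput than Transformers of similar size.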
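A plain-Python sketch of the per-token inference step for one layer, showing that memory stays fixed as generation proceeds. It assumes a Mamba-style layout: D channels, each with its own N-dimensional state, diagonal A and its own Δ_t, while B_t and C_t are shared across channels. The class and method names are hypothetical, and nothing here is fused; a real kernel would perform the same arithmetic in a single GPU launch.

```python
import math


class SelectiveSSMState:
    """Recurrent state of one layer during generation (illustrative sketch)."""

    def __init__(self, a):
        # a: D rows (one per channel), each holding N negative diagonal entries of A
        self.a = a
        self.x = [[0.0] * len(row) for row in a]   # D x N state, fixed size

    def step(self, u, delta, b, c):
        """One token. u, delta: length D; b, c: length N. Returns y_t (length D)."""
        y = []
        for ch, (a_row, x_row) in enumerate(zip(self.a, self.x)):
            d = delta[ch]
            for n in range(len(x_row)):
                # x_t = exp(Δ_t a) * x_{t-1} + Δ_t * B_t * u_t, updated in place
                x_row[n] = math.exp(d * a_row[n]) * x_row[n] + d * b[n] * u[ch]
            y.append(sum(c_n * x_n for c_n, x_n in zip(c, x_row)))   # C_t · x_t
        return y


state = SelectiveSSMState(a=[[-1.0, -2.0, -4.0] for _ in range(4)])   # D=4, N=3
for _ in range(1000):                                                  # 1000 tokens
    y = state.step(u=[1.0] * 4, delta=[0.1] * 4, b=[1.0, 0.5, 0.25], c=[1.0, 1.0, 1.0])
print(sum(len(row) for row in state.x))   # 12 numbers, however many tokens were processed
```

With constant inputs the state settles to a steady value, but the point is the size: the loop could run for a million tokens and the state would still hold D × N numbers.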


Architecture Recap

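Input u_t (d_model)
    |
    └─→ Linear (d_model → 2·E·d_model), split into two branches   (E = expansion factor, 2 in the paper)
          |
          ├─ main branch: causal depthwise Conv1d → SiLU → v_t   (E·d_model channels)
          |     ├─→ Linear → softplus → Δ_t   (one step size per channel)
          |     ├─→ Linear → B_t              (N values, shared by all channels)
          |     └─→ Linear → C_t              (N values, shared by all channels)
          |
          └─ gate branch: z_t                (E·d_model channels)

    Each channel c keeps its own state x^(c) ∈ ℝ^N (A diagonal):
    x_t^(c) = exp(Δ_t^(c) A^(c)) ⊙ x_{t-1}^(c) + Δ_t^(c) B_t v_t^(c)
    s_t^(c) = C_t · x_t^(c) + D^(c) v_t^(c)      (D: learned skip term)

    Output: y_t = s_t ⊙ SiLU(z_t)

    y_t is projected back to d_model (via another Linear layer)

The gating (SiLU activation on z_t) is inspired by gating in RNNs and LSTMs: it lets the model control how much of the SSM output flows on to the next layer.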
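A small helper that traces tensor shapes through a Mamba-style block for a length-L input. The defaults E = 2 and N = 16 are assumed typical values; the function name and dictionary keys are illustrative, not an official API, and the convolution and Δ-projection internals are omitted.

```python
def mamba_block_shapes(seq_len: int, d_model: int, expand: int = 2, d_state: int = 16):
    """Trace tensor shapes through a Mamba-style block (illustrative sketch)."""
    d_inner = expand * d_model            # channels after the input projection
    return {
        "input u": (seq_len, d_model),
        "input projection (main + gate)": (seq_len, 2 * d_inner),
        "main branch v after conv + SiLU": (seq_len, d_inner),
        "gate z": (seq_len, d_inner),
        "step sizes delta (one per channel)": (seq_len, d_inner),
        "B (shared by channels)": (seq_len, d_state),
        "C (shared by channels)": (seq_len, d_state),
        "state x_t at one step": (d_inner, d_state),
        "gated output s * SiLU(z)": (seq_len, d_inner),
        "output projection": (seq_len, d_model),
    }


for name, shape in mamba_block_shapes(seq_len=1024, d_model=768).items():
    print(f"{name:36s} {shape}")
```

The state's shape depends only on d_model, E and N, never on seq_len, which is where the constant-memory property comes from.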


The Trade-off

What Mamba Gains:

  • ✓ O(n) inference time (no O(n²) attention)
  • ✓ Constant memory per step: a fixed-size state instead of a KV cache that grows with n
  • ✓ Selectivity (input-dependent parameters)

What Mamba Sacrifices:

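  • ✗ No attention: the past is compressed into a fixed-size state, so the model can't look back at arbitrary earlier tokens with full flexibility
  • ✗ Harder to implement (needs custom kernels)
  • ✗ Recurrent structure: training parallelizes only through a scan, which is harder to make fast than attention's batched matrix multiplies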

The Bet: Selectivity + efficiency > full attention flexibility, for most tasks.


Comparison with Transformers

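Aspect | Mamba | Transformer
Forward pass | Parallel scan in training; recurrent (one step per token) at inference | Parallel over all tokens (batched)
Memory per step | Fixed-size state: O(N) per channel, independent of n | KV cache: O(n·d), growing with every token
Inference cost per new token | O(1) (O(n) for n tokens) | O(n) attention over the KV cache (O(n²) for n tokens)
Long-context speed | Faster; the paper reports about 5× higher generation throughput | Standard baseline
Context length limit | No architectural limit (bounded only by compute) | Practical limit ~128K (KV-cache memory)
In-context recall | Struggles with exact recall at very long range (everything must fit in the state) | Excellent (full attention over every past token)
Training speed | Similar to Transformers | Similar to Mamba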
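Rough per-layer arithmetic behind the memory row, as a sketch. It assumes a KV cache holding one key and one value vector of size d_model per token, and a Mamba state of E·d_model channels × N entries with E = 2 and N = 16; it ignores batch size, numeric precision and Mamba's small convolution buffer.

```python
def kv_cache_floats(n_tokens: int, d_model: int) -> int:
    """Keys + values for every past token in one attention layer."""
    return 2 * n_tokens * d_model


def ssm_state_floats(d_model: int, expand: int = 2, d_state: int = 16) -> int:
    """Fixed recurrent state of one Mamba-style layer; independent of sequence length."""
    return expand * d_model * d_state


d_model = 2048
for n in (16, 2_048, 65_536):
    print(f"n={n:>6}: KV cache {kv_cache_floats(n, d_model):>12,}  "
          f"SSM state {ssm_state_floats(d_model):>9,}")
```

Under these assumptions the two are equal at n = E·N/2 = 16 tokens; beyond that the KV cache keeps growing linearly while the SSM state stays at 65,536 numbers per layer.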
