Section 04

The Math: Discretisation and Selective State Spaces

Mamba: Linear-Time Sequence Modeling with Selective State Spaces (2023)

Prerequisites: Eigenvalues and Eigenvectors, Matrix Multiplication


Part 1: Continuous State Space Model

A linear time-invariant (LTI) system in continuous time:

dx/dt = A x(t) + B u(t)           (state equation)
y(t) = C x(t) + D u(t)            (output equation)

Where:
  x(t) ∈ ℝ^N           (hidden state at time t)
  u(t) ∈ ℝ             (input at time t)
  y(t) ∈ ℝ             (output at time t)
  A ∈ ℝ^(N×N)          (state transition matrix)
  B ∈ ℝ^(N×1)          (input projection)
  C ∈ ℝ^(1×N)          (output projection)
  D ∈ ℝ                (feedthrough, usually 0)

Example: Exponential Decay

Simplest case: dx/dt = -0.5 x(t) + u(t)

Solution: x(t) = e^(-0.5t) x(0) + ∫₀^t e^(-0.5(t-τ)) u(τ) dτ

Interpretation:
  - Initial state x(0) decays as e^(-0.5t) (half-life ≈ 1.4 time units)
  - Each input u(τ) contributes to the current state, weighted by e^(-0.5(t-τ))
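  - Older inputs have decayed more: every contribution shrinks at the same rate, so an input from further back carries a smaller weight e^(-0.5(t-τ))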
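
The closed-form solution can be checked numerically. The sketch below is only an illustration: it assumes a constant input u = 1 and an initial state x(0) = 1 (neither is specified in the text), and compares the formula against a simple forward-Euler integration. The function names closed_form and euler are made up for this example.

```python
import math

a = -0.5  # decay rate from dx/dt = -0.5 x(t) + u(t)

def closed_form(t, x0, u):
    """x(t) when the input is held at a constant value u (assumption for this check)."""
    decay = math.exp(a * t)
    # The integral of e^(a (t - tau)) over [0, t] equals (e^(a t) - 1) / a.
    return decay * x0 + (decay - 1.0) / a * u

def euler(t, x0, u, steps=100_000):
    """Integrate dx/dt = a*x + u with many small forward-Euler steps."""
    dt = t / steps
    x = x0
    for _ in range(steps):
        x += dt * (a * x + u)
    return x

half_life = math.log(2) / abs(a)   # time for e^(a t) to fall to 0.5
exact = closed_form(3.0, x0=1.0, u=1.0)
approx = euler(3.0, x0=1.0, u=1.0)
print(f"half-life = {half_life:.3f}")
print(f"closed form = {exact:.4f}, Euler = {approx:.4f}")
```

The two values should agree to several decimal places, and the half-life comes out as ln(2)/0.5.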

Part 2: Discretisation (Zero-Order Hold)

In practice, we process discrete sequences (tokens), not continuous signals. We need to convert the continuous system to discrete time.

Assume the input u(t) is piecewise constant over intervals [t, t+Δt):

u(t) = u_k   for t ∈ [k·Δt, (k+1)·Δt)  (u_k is constant during interval k)

The solution from t_k to t_{k+1} = t_k + Δt is:

x(t+Δt) = e^(A·Δt) x(t) + ∫_0^Δt e^(A(Δt-τ)) B u(t+τ) dτ

Since u is constant:
x(t+Δt) = e^(A·Δt) x(t) + (∫_0^Δt e^(A(Δt-τ)) dτ) B u(t)

Compute the integral:
∫_0^Δt e^(A(Δt-τ)) dτ

Let s = Δt - τ, so ds = -dτ and the limits swap; the sign change restores their order:
∫_0^Δt e^(A(Δt-τ)) dτ = ∫_0^Δt e^(As) ds

Because d/ds [A^(-1) e^(As)] = e^(As), and assuming A is invertible:
∫_0^Δt e^(As) ds = A^(-1) (e^(A·Δt) - I)

So the discrete recurrence is:
x_{k+1} = e^(A·Δt) x_k + A^(-1) (e^(A·Δt) - I) B u_k

Define Discretised Matrices

Let:

Ā = e^(A·Δt)                                      (discretised state transition)
B̄ = A^(-1) (e^(A·Δt) - I) B                      (discretised input matrix, exact under zero-order hold)
B̄ ≈ Δt·B                                         (first-order approximation, valid for small Δt)

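Relabelling the input held over each interval with the index of the state it produces (the usual SSM convention), the discrete system becomes: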

x_k = Ā x_{k-1} + B̄ u_k
y_k = C x_k

This is now a discrete recurrence that we can compute step by step.
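
One way to check the zero-order-hold formulas is sketched here for a scalar system (N = 1), where A^(-1) is a plain division. An input is held constant over one step of length Δt, the continuous equation is integrated finely, and the result is compared with a single discrete step. The function names and all numeric values are assumptions chosen for the illustration.

```python
import math

def zoh(a, b, dt):
    """Exact zero-order-hold discretisation of dx/dt = a*x + b*u (scalar, a != 0)."""
    a_bar = math.exp(a * dt)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def integrate(a, b, x, u, dt, substeps=200_000):
    """Fine forward-Euler integration over one interval with u held constant."""
    h = dt / substeps
    for _ in range(substeps):
        x += h * (a * x + b * u)
    return x

a, b, dt = -0.5, 1.0, 0.8
a_bar, b_bar = zoh(a, b, dt)

x_prev, u_k = 1.5, 2.0
x_discrete = a_bar * x_prev + b_bar * u_k         # one discrete step
x_continuous = integrate(a, b, x_prev, u_k, dt)   # continuous reference
print(x_discrete, x_continuous)
```

Because the input really is constant over the interval, the discrete step reproduces the continuous solution up to the integrator's own error; no approximation is involved in Ā or in the exact form of B̄.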


Part 3: Mamba’s Selective Discretisation

Mamba modifies the discretised system to be input-dependent:

Δ_k = softplus(W_Δ u_k)         (step size depends on input, scalar)
B_k = Linear_B(u_k)              (input matrix depends on input, N-dimensional)
C_k = Linear_C(u_k)              (output matrix depends on input, N-dimensional)

Ā_k = e^(Δ_k A)                  (discretised state transition, input-dependent)
B̄_k = (Δ_k A)^(-1) (e^(Δ_k A) - I) Δ_k B_k  ≈ Δ_k B_k

x_k = Ā_k x_{k-1} + B̄_k u_k
y_k = C_k x_k

Key differences from fixed SSM:

  1. Δ_k varies per token — allows adaptive memory decay
  2. B_k varies per token — allows input-dependent importance weighting
  3. C_k varies per token — allows input-dependent readout
  4. Ā_k varies per token — a consequence of Δ_k varying

Part 4: Simplification for Computation

Evaluating the matrix exponential e^(Δ_k A) and the inverse A^(-1) for a dense N×N matrix at every token is expensive. Mamba avoids this with two simplifications:

For Small Δ_k (Taylor Expansion)

e^(Δ_k A) = I + Δ_k A + (Δ_k A)²/2! + ...

For small Δ_k:
e^(Δ_k A) ≈ I + Δ_k A

Then:
A^(-1) (e^(Δ_k A) - I) ≈ A^(-1) (Δ_k A) = Δ_k I

So: B̄_k ≈ Δ_k B_k
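
To get a feel for how good the approximation B̄_k ≈ Δ_k B_k is, the following sketch compares it with the exact zero-order-hold factor for a single diagonal entry a. The value a = -0.9 and the list of step sizes are arbitrary choices for illustration, and zoh_factor is a hypothetical helper name.

```python
import math

def zoh_factor(a, delta):
    """Exact scalar factor a^(-1) (e^(delta*a) - 1) that multiplies B."""
    return math.expm1(delta * a) / a

a = -0.9
rows = []
for delta in (0.01, 0.1, 0.5, 1.0):
    exact = zoh_factor(a, delta)
    approx = delta                      # first-order Taylor result: delta * I
    rel_err = abs(approx - exact) / exact
    rows.append((delta, exact, rel_err))
    print(f"delta={delta:<5} exact={exact:.5f} approx={approx:.5f} rel.err={rel_err:.1%}")
```

For small steps the relative error is roughly Δ_k·|a|/2, so the shortcut is tight only when Δ_k·|a| is well below 1.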

For Diagonal A (Efficient Computation)

If A is diagonal (A = diag(a₁, a₂, …, a_N)):

e^(Δ_k A) = diag(e^(Δ_k a₁), e^(Δ_k a₂), ..., e^(Δ_k a_N))

Matrix multiply is cheap:
x_k = diag(...) x_{k-1} + Δ_k B_k u_k

This is just element-wise multiplication!
O(N) instead of O(N²).

Mamba uses diagonal SSMs to keep computation efficient.
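
A small sketch of why the diagonal case is cheap: with a dense N×N matrix the state update needs a full matrix-vector product, while a diagonal A needs only one multiplication per state entry. Plain Python lists keep the example dependency-free; the function names and the numeric values are assumptions for illustration.

```python
import math

def step_dense(A_bar, x_prev, B_bar, u):
    """x_k = A_bar x_{k-1} + B_bar u_k with a dense A_bar: O(N^2) multiplications."""
    n = len(x_prev)
    return [sum(A_bar[i][j] * x_prev[j] for j in range(n)) + B_bar[i] * u
            for i in range(n)]

def step_diagonal(a_diag, delta, x_prev, B, u):
    """Same update when A = diag(a_diag): O(N) element-wise work."""
    return [math.exp(delta * a) * x + delta * b * u
            for a, x, b in zip(a_diag, x_prev, B)]

a_diag = [-0.9, -0.8]
delta, u = 0.5, 1.0
x_prev, B = [1.0, 2.0], [1.0, 1.0]

# Build the equivalent dense matrices only to confirm both paths agree.
A_bar = [[math.exp(delta * a_diag[0]), 0.0],
         [0.0, math.exp(delta * a_diag[1])]]
B_bar = [delta * b for b in B]

x_dense = step_dense(A_bar, x_prev, B_bar, u)
x_diag = step_diagonal(a_diag, delta, x_prev, B, u)
print(x_dense, x_diag)
```

Both functions return the same state; the diagonal version never materialises the N×N matrix.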


Part 5: Worked Numerical Example

Let’s trace a complete example with real numbers.

Setup

State dimension: N = 2
Sequence length: T = 3
Input sequence: u = [1.0, 0.5, 2.0]
Fixed state matrix: A = [[-0.9, 0], [0, -0.8]]  (stable, diagonal)
Fixed input matrix: B = [[1.0], [1.0]]
Fixed output matrix: C = [1.0, 1.0]
Initial state: x_0 = [0, 0]

Selective parameters (learned, for this example):
  For u_k = 1.0: Δ_k = softplus(0.5) ≈ 0.974 (mid-sized step)
  For u_k = 0.5: Δ_k = softplus(-0.2) ≈ 0.598 (smallest step: previous state is retained most, input is written least)
  For u_k = 2.0: Δ_k = softplus(1.0) ≈ 1.313 (largest step: previous state is forgotten most, input is written most)

For simplicity, use the first-order approximation for B̄ and the exact exponential for Ā:
  B̄_k ≈ Δ_k B = Δ_k [[1.0], [1.0]]
  Ā_k = e^(Δ_k A)

Step 1: t = 1, u₁ = 1.0

Δ₁ = softplus(0.5) ≈ 0.974
B̄₁ = 0.974 × [[1.0], [1.0]] = [[0.974], [0.974]]

Ā₁ = e^(0.974 × [[-0.9, 0], [0, -0.8]])
   = [[e^(-0.877), 0], [0, e^(-0.779)]]
   = [[0.416, 0], [0, 0.459]]

x₁ = Ā₁ x₀ + B̄₁ u₁
   = [[0.416, 0], [0, 0.459]] × [[0], [0]] + [[0.974], [0.974]] × 1.0
   = [[0], [0]] + [[0.974], [0.974]]
   = [[0.974], [0.974]]

y₁ = C x₁ = [1.0, 1.0] × [[0.974], [0.974]] = 0.974 + 0.974 = 1.948

Step 2: t = 2, u₂ = 0.5

Δ₂ = softplus(-0.2) ≈ 0.598
B̄₂ = 0.598 × [[1.0], [1.0]] = [[0.598], [0.598]]

Ā₂ = e^(0.598 × [[-0.9, 0], [0, -0.8]])
   = [[e^(-0.538), 0], [0, e^(-0.478)]]
   = [[0.584, 0], [0, 0.620]]

x₂ = Ā₂ x₁ + B̄₂ u₂
   = [[0.584, 0], [0, 0.620]] × [[0.974], [0.974]] + [[0.598], [0.598]] × 0.5
   = [[0.584×0.974], [0.620×0.974]] + [[0.299], [0.299]]
   = [[0.569], [0.604]] + [[0.299], [0.299]]
   = [[0.868], [0.903]]

y₂ = C x₂ = [1.0, 1.0] × [[0.868], [0.903]] = 0.868 + 0.903 = 1.771

Step 3: t = 3, u₃ = 2.0

Δ₃ = softplus(1.0) ≈ 1.313
B̄₃ = 1.313 × [[1.0], [1.0]] = [[1.313], [1.313]]

Ā₃ = e^(1.313 × [[-0.9, 0], [0, -0.8]])
   = [[e^(-1.182), 0], [0, e^(-1.050)]]
   = [[0.307, 0], [0, 0.350]]

x₃ = Ā₃ x₂ + B̄₃ u₃
   = [[0.307, 0], [0, 0.350]] × [[0.868], [0.903]] + [[1.313], [1.313]] × 2.0
   = [[0.307×0.868], [0.350×0.903]] + [[2.626], [2.626]]
   = [[0.266], [0.316]] + [[2.626], [2.626]]
   = [[2.892], [2.942]]

y₃ = C x₃ = [1.0, 1.0] × [[2.892], [2.942]] = 2.892 + 2.942 = 5.834
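
The three steps above can be replayed in a few lines. This is a minimal sketch that takes the softplus arguments from the Setup as given numbers rather than computing them from learned weights, and uses the same choices as the hand calculation (exponential for Ā_k, Δ_k B for B̄_k). The variable names are chosen for this illustration.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

a_diag = [-0.9, -0.8]          # diagonal of A
B = [1.0, 1.0]
C = [1.0, 1.0]
u_seq = [1.0, 0.5, 2.0]
pre_act = [0.5, -0.2, 1.0]     # arguments of softplus from the Setup

x = [0.0, 0.0]                 # x_0
states, outputs = [], []
for u, z in zip(u_seq, pre_act):
    delta = softplus(z)
    # x_k = exp(delta * A) x_{k-1} + delta * B * u_k, element-wise for diagonal A
    x = [math.exp(delta * a) * xi + delta * b * u
         for a, xi, b in zip(a_diag, x, B)]
    y = sum(c * xi for c, xi in zip(C, x))
    states.append(x)
    outputs.append(y)
    print([round(v, 3) for v in x], round(y, 3))
```

The printed states and outputs agree with the hand calculation up to rounding in the third decimal place, since the code carries full precision between steps.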

Part 6: Interpretation

Memory Trace

t=1: x₁ = [0.974, 0.974]   (input 1.0 with Δ=0.974 encoded)
t=2: x₂ = [0.868, 0.903]   (previous state decayed by Ā₂, plus new input 0.5)
t=3: x₃ = [2.892, 2.942]   (previous state decayed by Ā₃, plus large new input 2.0)

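Key observation: When u₃ = 2.0 arrived, Δ₃ was the largest step size (1.313). A larger Δ has two effects. Because A is negative, it makes Ā₃ = e^(Δ₃ A) smaller (0.307 and 0.350, the smallest retention factors of the three steps), so more of the previous state is forgotten. It also makes B̄₃ = Δ₃ B larger, so the current input is written into the state more strongly. So:

  • A large Δ resets the state and focuses on the current input
  • A small Δ preserves the existing state and largely ignores the current input
  • Because Δ depends on the token, the model can decide per token how much to keep and how much to overwrite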

(Note: In real training, the Linear layers for Δ, B, C are learned to produce these selective values.)
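
The effect of Δ can be made concrete with a short sketch. For one diagonal entry a < 0, it prints the retention factor e^(Δa) that multiplies the old state next to the write weight Δ that multiplies the new input under the B̄ ≈ ΔB approximation with B = 1. The value a = -0.9 and the step sizes are arbitrary illustrative choices.

```python
import math

a = -0.9
table = []
for delta in (0.1, 0.5, 1.0, 2.0):
    retain = math.exp(delta * a)   # fraction of the previous state that survives
    write = delta                  # weight on the current input (B = 1)
    table.append((delta, retain, write))
    print(f"delta={delta:.1f}  retain={retain:.3f}  write={write:.3f}")
```

As Δ grows, the retention factor falls towards 0 while the write weight rises, so a single number trades off remembering the past against absorbing the present.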


Part 7: Stability and Eigenvalues

The fixed matrix A should be stable — all eigenvalues should have negative real parts.

For the diagonal example:

A = [[-0.9, 0], [0, -0.8]]
Eigenvalues: λ₁ = -0.9, λ₂ = -0.8  (both negative ✓)

Both are negative, so:

  • x(t) = e^(λt) x(0) → 0 as t → ∞ (exponential decay)
  • The system is stable

If A had a positive eigenvalue, e^(λt) would grow unboundedly (unstable).

Mamba constrains A to be stable (via initialization or parameterization).
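
A brief sketch of what stability means after discretisation: with a negative eigenvalue λ and a positive step size Δ (softplus is always positive), the discrete factor e^(Δλ) lies strictly between 0 and 1, so repeated application shrinks the state. The second part illustrates one possible parameterization that enforces negativity by storing an unconstrained value and mapping it through -exp(·). The name a_log and the sample values are assumptions for this illustration.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

eigenvalues = [-0.9, -0.8]

# Discrete decay factors for a range of raw step-size inputs.
factors = [math.exp(softplus(z) * lam)
           for z in (-5.0, 0.0, 5.0)
           for lam in eigenvalues]
print(all(0.0 < f < 1.0 for f in factors))

# An unconstrained parameter a_log always maps to a negative eigenvalue.
recovered = [-math.exp(a_log) for a_log in (-3.0, 0.0, 2.0)]
print(recovered)

# A positive eigenvalue would instead give a factor above 1 (growth).
unstable_factor = math.exp(softplus(0.0) * 0.5)
print(unstable_factor > 1.0)
```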


Part 8: Comparison with Fixed SSM

Fixed SSM (e.g., S4)

Ā = e^(Δ A)     (Δ is the same for every token: it may be learned, but it does not depend on the input)
B̄ = Δ B        (the same for every token)

For all tokens, same state transition Ā.
For all tokens, same input importance B̄.

Mamba (Selective SSM)

Δ_k = softplus(W_Δ u_k)    (varies per token)
Ā_k = e^(Δ_k A)              (varies per token)
B̄_k = Δ_k B_k                (varies per token, since both Δ_k and B_k vary)
C_k = Linear_C(u_k)           (varies per token)

Different tokens get different state transitions and input importance.

The fixed SSM is simpler but less expressive. Mamba’s selectivity is what makes it competitive.
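
The difference can be illustrated with a toy scalar experiment. One relevant token is followed either by zeros or by distractor values, and the final states are compared. With a fixed step size the distractors overwrite the state; with a tiny step size on the distractor positions the state barely moves. The step sizes here are hand-picked for the demonstration, not learned, and every name and value is an assumption.

```python
import math

def run(u_seq, deltas, a=-0.5):
    """Scalar SSM x_k = exp(delta_k a) x_{k-1} + delta_k u_k; returns the final state."""
    x = 0.0
    for u, d in zip(u_seq, deltas):
        x = math.exp(d * a) * x + d * u
    return x

signal = [1.0, 0.0, 0.0, 0.0]            # one relevant token, then nothing
noisy = [1.0, 5.0, -4.0, 3.0]            # same token followed by distractors

fixed = [1.0, 1.0, 1.0, 1.0]             # same step size for every token
selective = [1.0, 1e-3, 1e-3, 1e-3]      # hand-picked: tiny delta on distractors

fixed_gap = abs(run(noisy, fixed) - run(signal, fixed))
selective_gap = abs(run(noisy, selective) - run(signal, selective))
print(fixed_gap, selective_gap)
```

The gap measures how much the distractors change the final state. In a trained model the small step sizes would have to come from the learned projection of each token rather than being set by hand.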


Summary: Full Mamba Forward Pass

For each token t:
  1. Compute selective parameters:
     Δ_t = softplus(W_Δ u_t)
     B_t = Linear_B(u_t)
     C_t = Linear_C(u_t)
  
  2. Discretise:
     Ā_t = exp(Δ_t A)
     B̄_t = Δ_t B_t  (approximation)
  
  3. Update state:
     x_t = Ā_t x_{t-1} + B̄_t u_t
  
  4. Compute output:
     y_t = C_t x_t

Total per-token complexity: O(N²) for matrix multiply, or O(N) with diagonal A
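
The four steps can be combined into a single loop. The sketch below is a simplified single-channel illustration under stated assumptions: the input is one scalar per token, A is diagonal, and the projections are reduced to hypothetical weights (w_delta, w_B, w_C) with made-up values and no bias terms. It is a plain sequential loop, not an optimized implementation.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

def selective_ssm(u_seq, a_diag, w_delta, w_B, w_C):
    """Sequential selective SSM for scalar inputs and diagonal A."""
    x = [0.0] * len(a_diag)
    ys = []
    for u in u_seq:
        # 1. Selective parameters (all depend on the current input u)
        delta = softplus(w_delta * u)
        B = [w * u for w in w_B]
        C = [w * u for w in w_C]
        # 2. Discretise: exponential for A_bar, delta * B for B_bar
        A_bar = [math.exp(delta * a) for a in a_diag]
        B_bar = [delta * b for b in B]
        # 3. Update state element-wise
        x = [ab * xi + bb * u for ab, xi, bb in zip(A_bar, x, B_bar)]
        # 4. Read out
        ys.append(sum(c * xi for c, xi in zip(C, x)))
    return ys

# Hypothetical weights, chosen only to make the example run.
ys = selective_ssm(
    u_seq=[1.0, 0.5, 2.0],
    a_diag=[-0.9, -0.8],
    w_delta=0.5,
    w_B=[1.0, 0.5],
    w_C=[0.3, -0.2],
)
print(ys)
```

Each iteration touches every state entry once, which matches the O(N) per-token cost quoted for diagonal A.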
