6. The code — a minimal LSTM cell in PyTorch
Runs as-is on a free Google Colab CPU. No GPU needed, no data download. Open a new Colab notebook, paste this in, and hit run.
The cell (25 lines)
import torch  # PyTorch: tensors + autograd
import torch.nn as nn  # the neural-network module


class MiniLSTMCell(nn.Module):  # our LSTM cell, from scratch
    def __init__(self, input_size, hidden_size):  # define shapes
        super().__init__()  # standard nn.Module boilerplate
        self.hidden_size = hidden_size  # remember hidden size for zeros
        # one big linear layer producing all 4 gate pre-activations at once
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):  # x: (batch, input_size)
        h_prev, c_prev = state  # unpack previous hidden + cell
        combined = torch.cat([x, h_prev], dim=1)  # join input and prev hidden
        z = self.gates(combined)  # 4 gate pre-activations
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)  # split into four equal parts
        f = torch.sigmoid(z_f)  # forget gate
        i = torch.sigmoid(z_i)  # input gate
        c_tilde = torch.tanh(z_c)  # candidate content
        o = torch.sigmoid(z_o)  # output gate
        c = f * c_prev + i * c_tilde  # cell state update (eq 4)
        h = o * torch.tanh(c)  # hidden state (eq 6)
        return h, (h, c)  # return output + new state
That is the whole cell. Nothing hidden.
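To make "nothing hidden" concrete, here is a small sketch that counts the learnable parameters in the single gates layer. It assumes the toy sizes used in this section (input size 3, hidden size 2); the variable names are just for illustration.

```python
import torch.nn as nn

input_size, hidden_size = 3, 2  # assumed toy sizes, matching this section's test run

# Same layer shape as self.gates in MiniLSTMCell
gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

n_weights = gates.weight.numel()  # (4 * hidden) rows x (input + hidden) columns
n_biases = gates.bias.numel()     # one bias per gate pre-activation
n_params = sum(p.numel() for p in gates.parameters())

print("weights:", n_weights)  # 8 * 5 = 40
print("biases :", n_biases)   # 8
print("total  :", n_params)   # 48
```

Every number the cell can learn lives in those 48 entries; the sigmoid, tanh, multiply, and add steps have no parameters of their own.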
A tiny test run
Append this to the same notebook cell (or add a second cell):
cell = MiniLSTMCell(input_size=3, hidden_size=2) # 3-dim input, 2-dim memory
x = torch.randn(1, 3) # one fake input vector
h = torch.zeros(1, 2) # initial hidden state
c = torch.zeros(1, 2) # initial cell state
out, (h, c) = cell(x, (h, c)) # run one step
print("hidden:", out)
print("cell :", c)
You should see something like:
hidden: tensor([[-0.0421, 0.1137]], grad_fn=<MulBackward0>)
cell : tensor([[-0.0952, 0.2334]], grad_fn=<AddBackward0>)
The exact numbers differ every run because the weights are randomly initialised — that’s expected. What matters is that it runs without error and the shapes are right: 1×2 for a batch of 1 and hidden size 2.
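If you would rather get the same numbers on every run, one option is to seed PyTorch's random number generator before building the cell and drawing the input. The sketch below assumes a CPU run; the helper name one_step is made up for this example, and the cell is repeated in compact form only so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class MiniLSTMCell(nn.Module):  # same cell as in this section, compacted
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        z = self.gates(torch.cat([x, h_prev], dim=1))
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)
        c = torch.sigmoid(z_f) * c_prev + torch.sigmoid(z_i) * torch.tanh(z_c)
        h = torch.sigmoid(z_o) * torch.tanh(c)
        return h, (h, c)


def one_step(seed):  # hypothetical helper: build a cell and run one step
    torch.manual_seed(seed)  # fixes both the weight init and the fake input
    cell = MiniLSTMCell(input_size=3, hidden_size=2)
    x = torch.randn(1, 3)
    state = (torch.zeros(1, 2), torch.zeros(1, 2))
    out, _ = cell(x, state)
    return out.detach()


a = one_step(0)
b = one_step(0)
print("same seed, same output:", torch.equal(a, b))
print("shape:", tuple(a.shape))
```

Changing the seed changes the numbers but not the shape, which is the part worth checking.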
Unroll it over a sequence
This is where LSTMs earn their name. Let’s feed it a sequence of length 5:
seq = torch.randn(5, 1, 3) # 5 time steps, batch 1, dim 3
h = torch.zeros(1, 2) # reset hidden state
c = torch.zeros(1, 2) # reset cell state
for t, x_t in enumerate(seq):  # loop over time
    out, (h, c) = cell(x_t, (h, c))  # one LSTM step
    print(f"step {t}: cell = {c.detach().numpy()}")
Watch the cell state drift from [0, 0] as new information is written
and old information is partially kept. That drift is the network’s
memory forming. In a real training loop, the gates would learn to keep
useful memory and discard noise.
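The loop above only prints the cell state. If you also want to keep the hidden state from every step, a minimal sketch is to append each output to a list and stack the list at the end. The cell is repeated in compact form so the snippet runs standalone; the sizes are the same toy values as before.

```python
import torch
import torch.nn as nn


class MiniLSTMCell(nn.Module):  # same cell as in this section, compacted
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        z = self.gates(torch.cat([x, h_prev], dim=1))
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)
        c = torch.sigmoid(z_f) * c_prev + torch.sigmoid(z_i) * torch.tanh(z_c)
        h = torch.sigmoid(z_o) * torch.tanh(c)
        return h, (h, c)


cell = MiniLSTMCell(input_size=3, hidden_size=2)
seq = torch.randn(5, 1, 3)  # 5 time steps, batch 1, dim 3
h = torch.zeros(1, 2)
c = torch.zeros(1, 2)

outputs = []
for x_t in seq:
    out, (h, c) = cell(x_t, (h, c))
    outputs.append(out)  # keep the hidden state from this step

outputs = torch.stack(outputs)  # (seq_len, batch, hidden)
print("stacked shape:", tuple(outputs.shape))
print("last output is the final hidden state:", torch.equal(outputs[-1], h))
```

The stacked tensor has shape (5, 1, 2), and its last slice is the same tensor as the final h, because the cell returns the hidden state both as its output and as part of the new state.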
Compare to PyTorch’s built-in
Once you understand this, use the library version. It is faster and adds stacked layers, a bidirectional mode, inter-layer dropout, and cuDNN-accelerated GPU kernels:
lstm = nn.LSTM(input_size=3, hidden_size=2, batch_first=True)
seq_batch = torch.randn(1, 5, 3) # (batch, seq_len, input)
output, (h_final, c_final) = lstm(seq_batch)
print("output shape:", output.shape) # (1, 5, 2)
nn.LSTM is the same math we just wrote, vectorised in C++ and bound
to Python. Under the hood it is running our six equations per step.
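One way to check the "same math" claim is to copy the weights of PyTorch's single-step nn.LSTMCell into our cell and compare outputs. This is a sketch, not a definitive test. It relies on two documented layout details of the built-in: gates are stacked in the order input, forget, cell, output (ours is forget, input, cell, output), and the built-in keeps separate input and hidden weight matrices plus two bias vectors. The reorder helper is made up for this example.

```python
import torch
import torch.nn as nn


class MiniLSTMCell(nn.Module):  # same cell as in this section, compacted
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        z = self.gates(torch.cat([x, h_prev], dim=1))
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)
        c = torch.sigmoid(z_f) * c_prev + torch.sigmoid(z_i) * torch.tanh(z_c)
        h = torch.sigmoid(z_o) * torch.tanh(c)
        return h, (h, c)


torch.manual_seed(0)
input_size, hidden_size = 3, 2
ref = nn.LSTMCell(input_size, hidden_size)    # PyTorch's one-step cell
mini = MiniLSTMCell(input_size, hidden_size)


def reorder(t):  # hypothetical helper: (i, f, g, o) blocks -> (f, i, g, o)
    i, f, g, o = t.chunk(4, dim=0)
    return torch.cat([f, i, g, o], dim=0)


with torch.no_grad():
    # Our single Linear sees [x, h_prev], so its weight is [W_ih | W_hh].
    w = torch.cat([ref.weight_ih, ref.weight_hh], dim=1)
    mini.gates.weight.copy_(reorder(w))
    # The sum of the two built-in biases plays the role of our one bias.
    mini.gates.bias.copy_(reorder(ref.bias_ih + ref.bias_hh))

x = torch.randn(4, input_size)  # batch of 4
h0 = torch.randn(4, hidden_size)
c0 = torch.randn(4, hidden_size)

h_ref, c_ref = ref(x, (h0, c0))
h_mini, (_, c_mini) = mini(x, (h0, c0))

print("hidden match:", torch.allclose(h_ref, h_mini, atol=1e-5))
print("cell match  :", torch.allclose(c_ref, c_mini, atol=1e-5))
```

The comparison uses a small tolerance because the two implementations add the bias terms in a different order, which can change the last few floating-point digits.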
Exercises to try on Colab
- Change hidden_size to 8. How does the cell state shape change?
- Initialise the forget-gate bias to +2 (a classic LSTM trick). Why might this help in the early stages of training? (Hint: σ(2) ≈ 0.88. The network starts by mostly keeping memory, rather than erasing it randomly. This trick was popularised by Jozefowicz et al., 2015, a paper we will cover later.)
- Run the cell 100 times on random input. Plot c over time with matplotlib. Does it explode, drift, or stabilise?
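For the forget-gate exercise, here is a minimal sketch of one way to set the bias. It assumes the gate layout of the cell in this section, where chunk(4, dim=1) puts the forget gate first, so its biases are the first hidden_size entries of the gates layer's bias. The "why" is still yours to answer.

```python
import torch
import torch.nn as nn


class MiniLSTMCell(nn.Module):  # same cell as in this section, compacted
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        z = self.gates(torch.cat([x, h_prev], dim=1))
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)
        c = torch.sigmoid(z_f) * c_prev + torch.sigmoid(z_i) * torch.tanh(z_c)
        h = torch.sigmoid(z_o) * torch.tanh(c)
        return h, (h, c)


hidden_size = 2
cell = MiniLSTMCell(input_size=3, hidden_size=hidden_size)

with torch.no_grad():  # edit the parameter in place without tracking gradients
    # Assumption: forget gate is the first of the four chunks in this cell.
    cell.gates.bias[:hidden_size].fill_(2.0)

forget_bias = cell.gates.bias[:hidden_size].detach()
keep_fraction = torch.sigmoid(torch.tensor(2.0)).item()

print("forget-gate bias:", forget_bias)
print("sigmoid(2) is about", round(keep_fraction, 2))
```

With zero input and zero hidden state, the forget gate's pre-activation is just its bias, so the gate starts at about 0.88 instead of somewhere near 0.5.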
What to take away from the code
The whole LSTM cell is one Linear layer and a handful of element-wise
operations. It is not intricate. It is not deep. What made LSTMs work
was not engineering complexity — it was the idea that memory should
be separate from output and controlled by learned gates. The equations
from Section 5 are the whole story, and these 25 lines are their
literal translation to PyTorch.