Section 06

The code

Long Short-Term Memory 1997

6. The code — a minimal LSTM cell in PyTorch

Runs as-is on a free Google Colab CPU. No GPU needed, no data download. Open a new Colab notebook, paste this in, and hit run.

The cell (about 20 lines)

import torch                                     # PyTorch — tensors + autograd
import torch.nn as nn                            # the neural-network module

class MiniLSTMCell(nn.Module):                   # our LSTM cell, from scratch
    def __init__(self, input_size, hidden_size): # define shapes
        super().__init__()                       # standard nn.Module boilerplate
        self.hidden_size = hidden_size           # remember hidden size for zeros
        # one big linear layer producing all 4 gate pre-activations at once
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):                 # x: (batch, input_size)
        h_prev, c_prev = state                   # unpack previous hidden + cell
        combined = torch.cat([x, h_prev], dim=1) # join input and prev hidden
        z = self.gates(combined)                 # 4 gate pre-activations
        z_f, z_i, z_c, z_o = z.chunk(4, dim=1)   # split into four equal parts
        f = torch.sigmoid(z_f)                   # forget gate
        i = torch.sigmoid(z_i)                   # input gate
        c_tilde = torch.tanh(z_c)                # candidate content
        o = torch.sigmoid(z_o)                   # output gate
        c = f * c_prev + i * c_tilde             # cell state update (eq 4)
        h = o * torch.tanh(c)                    # hidden state (eq 6)
        return h, (h, c)                         # return output + new state

Nineteen lines of code plus one comment line. Nothing hidden.

A tiny test run

Append this to the same notebook cell (or add a second cell):

cell = MiniLSTMCell(input_size=3, hidden_size=2) # 3-dim input, 2-dim memory
x = torch.randn(1, 3)                            # one fake input vector
h = torch.zeros(1, 2)                            # initial hidden state
c = torch.zeros(1, 2)                            # initial cell state
out, (h, c) = cell(x, (h, c))                    # run one step
print("hidden:", out)
print("cell  :", c)

You should see something like:

hidden: tensor([[-0.0421,  0.1137]], grad_fn=<MulBackward0>)
cell  : tensor([[-0.0952,  0.2334]], grad_fn=<AddBackward0>)

The exact numbers differ every run because the weights are randomly initialised and x is drawn from torch.randn. That’s expected. What matters is that it runs without error and the shapes are right: 1×2 for a batch of 1 and hidden size 2.
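If you want the same numbers on every run, one option is to seed PyTorch's random number generator before building the layer. The sketch below is only an illustration: make_gates is a hypothetical helper that builds a standalone nn.Linear with the same shape as the gate layer in MiniLSTMCell, not part of the cell itself.

```python
import torch
import torch.nn as nn

def make_gates(seed):
    """Hypothetical helper: a gate layer shaped like MiniLSTMCell's (3 inputs, 2 hidden)."""
    torch.manual_seed(seed)          # fix the random number generator
    return nn.Linear(3 + 2, 4 * 2)   # weights are drawn when the layer is constructed

a = make_gates(0)
b = make_gates(0)                    # same seed as a
c = make_gates(1)                    # different seed

same_seed_matches = torch.equal(a.weight, b.weight)
other_seed_differs = not torch.equal(a.weight, c.weight)
print("same seed, same weights:     ", same_seed_matches)
print("other seed, different weights:", other_seed_differs)
```

Calling torch.manual_seed(0) at the top of the notebook cell, before creating MiniLSTMCell and x, should have the same effect on the test run above.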

Unroll it over a sequence

This is where LSTMs earn their name. Let’s feed it a sequence of length 5:

seq = torch.randn(5, 1, 3)                       # 5 time steps, batch 1, dim 3
h = torch.zeros(1, 2)                            # reset hidden state
c = torch.zeros(1, 2)                            # reset cell state
for t, x_t in enumerate(seq):                    # loop over time
    out, (h, c) = cell(x_t, (h, c))              # one LSTM step
    print(f"step {t}: cell = {c.detach().numpy()}")

Watch the cell state drift from [0, 0] as new information is written and old information is partially kept. That drift is the network’s memory forming. In a real training loop, the gates would learn to keep useful memory and discard noise.
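To see the drift in isolation, here is a scalar sketch of the cell state update with the gates frozen at made-up values. The constants f = 0.9, i = 0.1 and c_tilde = 1.0 are assumptions chosen for illustration; in the real cell they are vectors recomputed from the input at every step.

```python
# Scalar version of c = f * c_prev + i * c_tilde with constant, hand-picked gates.
f, i, c_tilde = 0.9, 0.1, 1.0   # assumed values, not learned
c = 0.0                         # start from an empty memory
history = []
for t in range(5):
    c = f * c + i * c_tilde     # keep 90% of the old memory, write 10% of the candidate
    history.append(c)
    print(f"step {t}: c = {c:.4f}")
```

With these constants the state after t steps is 1 - 0.9**t, so it creeps towards 1 without ever jumping there. A forget gate near 1 means slow change and long memory.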

Compare to PyTorch’s built-in

Once you understand this, use the library version. It is faster, processes a whole sequence in one call, and supports CUDA, stacked layers, dropout, etc.:

lstm = nn.LSTM(input_size=3, hidden_size=2, batch_first=True)
seq_batch = torch.randn(1, 5, 3)                 # (batch, seq_len, input)
output, (h_final, c_final) = lstm(seq_batch)
print("output shape:", output.shape)             # (1, 5, 2)

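nn.LSTM is the same math we just wrote, implemented in optimised C++ and bound to Python. Under the hood it runs our six equations per step. Only the parameter layout differs: it keeps separate input and hidden weight matrices, each with its own bias, and stacks the gates in the order input, forget, candidate, output.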
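One way to check the "same math" claim is to take the parameters out of nn.LSTM and run the update by hand. The sketch below relies on PyTorch's documented layout: weight_ih_l0, weight_hh_l0, bias_ih_l0 and bias_hh_l0, with gates stacked in the order input, forget, candidate, output. The step function is a hypothetical helper written for this comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
I, H = 3, 2
lstm = nn.LSTM(input_size=I, hidden_size=H, batch_first=True)

# Merge the two weight matrices and two biases into one "big linear layer".
W = torch.cat([lstm.weight_ih_l0, lstm.weight_hh_l0], dim=1)  # (4H, I+H)
b = lstm.bias_ih_l0 + lstm.bias_hh_l0                          # (4H,)

def step(x, h, c):
    """One hand-written LSTM step using nn.LSTM's own parameters."""
    z = torch.cat([x, h], dim=1) @ W.T + b        # all gate pre-activations
    z_i, z_f, z_g, z_o = z.chunk(4, dim=1)        # PyTorch order: i, f, g, o
    c = torch.sigmoid(z_f) * c + torch.sigmoid(z_i) * torch.tanh(z_g)
    h = torch.sigmoid(z_o) * torch.tanh(c)
    return h, c

seq = torch.randn(1, 5, I)                        # (batch, seq_len, input)
h = torch.zeros(1, H)
c = torch.zeros(1, H)
outputs = []
with torch.no_grad():
    for t in range(seq.shape[1]):
        h, c = step(seq[:, t, :], h, c)
        outputs.append(h)
    manual_out = torch.stack(outputs, dim=1)      # (1, 5, H)
    builtin_out, (h_n, c_n) = lstm(seq)

match = torch.allclose(manual_out, builtin_out, atol=1e-5)
print("manual loop matches nn.LSTM:", match)
```

Note the chunk order here is i, f, g, o, whereas MiniLSTMCell uses f, i, c, o. The order is only a storage convention, so both compute the same function once the weights are arranged accordingly.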

Exercises to try on Colab

  1. Change hidden_size to 8. How does the cell state shape change?
  2. Initialise the forget-gate bias to +2 (a classic LSTM trick). Why might this help in the early stages of training? (Hint: σ(2) ≈ 0.88. The network starts by mostly keeping memory, rather than erasing it randomly. The idea goes back to Gers et al., 2000, and was popularised by Jozefowicz et al., 2015, who recommended a positive forget-gate bias. We will cover the 2015 paper later.)
  3. Run the cell 100 times on random input. Plot c over time with matplotlib. Does it explode, drift, or stabilise?
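As a starting point for exercise 2, here is one possible way to set the forget-gate bias. It is a sketch on a standalone nn.Linear shaped like the gate layer in MiniLSTMCell, and it assumes the same split order as the cell (forget gate first), so the forget slice is the first hidden_size entries of the bias.

```python
import math
import torch
import torch.nn as nn

input_size, hidden_size = 3, 2
gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

# The forget gate is the first chunk, so its bias is the first hidden_size entries.
with torch.no_grad():
    gates.bias[:hidden_size].fill_(2.0)

# With an all-zero input and hidden state, the pre-activation equals the bias.
x = torch.zeros(1, input_size)
h_prev = torch.zeros(1, hidden_size)
z_f = gates(torch.cat([x, h_prev], dim=1))[:, :hidden_size]
f = torch.sigmoid(z_f)

print("forget gate at start:", f)
print("sigma(2) =", 1 / (1 + math.exp(-2)))
```

In the class itself, the same two lines could go at the end of __init__, operating on self.gates.bias.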

What to take away from the code

The whole LSTM cell is one Linear layer and a handful of element-wise operations. It is not intricate. It is not deep. What made LSTMs work was not engineering complexity; it was the idea that memory should be separate from output and controlled by learned gates. The equations from Section 5 are the whole story, and these 20 or so lines are their literal translation to PyTorch.
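To make "one Linear layer" concrete, the sketch below counts the learnable parameters for the sizes used in this section (input 3, hidden 2). It builds the gate layer on its own rather than through MiniLSTMCell, and compares the count with nn.LSTM, which carries a second bias vector.

```python
import torch.nn as nn

input_size, hidden_size = 3, 2

# The single layer inside our cell: weights plus one bias per gate unit.
gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)
ours = sum(p.numel() for p in gates.parameters())

# Closed form: 4 gates, each with (input + hidden) weights and 1 bias per unit.
formula = 4 * hidden_size * (input_size + hidden_size + 1)

# The built-in layer keeps two bias vectors, so it has 4 * hidden_size more.
builtin = sum(p.numel() for p in nn.LSTM(input_size, hidden_size).parameters())

print("our cell:", ours)        # 48
print("formula :", formula)     # 48
print("nn.LSTM :", builtin)     # 56
```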

Next: the impact — what LSTMs changed in the real world.