RNN and LSTM

Recurrent neural networks process sequential data by maintaining a hidden state that carries information across time steps. This page derives the vanilla RNN, demonstrates why gradients vanish, builds LSTM and GRU from the equations up, implements them from scratch in PyTorch, and applies them to IMDB sentiment analysis and stock price prediction.

Vanilla RNN

The Equations

At each time step t, the RNN updates its hidden state:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

$$y_t = W_{hy} h_t + b_y$$

where:

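  • $x_t \in \mathbb{R}^d$ is the input at time $t$
  • $h_t \in \mathbb{R}^n$ is the hidden state
  • $W_{xh} \in \mathbb{R}^{n \times d}$ maps input to hidden
  • $W_{hh} \in \mathbb{R}^{n \times n}$ maps hidden to hidden (recurrence)
  • $W_{hy} \in \mathbb{R}^{o \times n}$ maps hidden to output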

Unrolling Through Time

The same weights $W_{hh}$ and $W_{xh}$ are shared across all time steps.

Worked Example — RNN Forward Pass Through 3 Timesteps

Setup: Input size $d=2$, hidden size $n=2$, 3 timesteps

$$W_{xh} = \begin{bmatrix} 0.5 & 0.3 \\ 0.2 & -0.4 \end{bmatrix}, \quad W_{hh} = \begin{bmatrix} 0.1 & 0.6 \\ -0.3 & 0.2 \end{bmatrix}, \quad b_h = [0, 0]$$

Inputs: $x_1=[1,0]$, $x_2=[0,1]$, $x_3=[1,1]$, $h_0=[0,0]$

Timestep 1: $h_1=\tanh(W_{xh}x_1+W_{hh}h_0+b_h)$

$$W_{xh}x_1 = [0.5(1)+0.3(0),\ 0.2(1)+(-0.4)(0)] = [0.5, 0.2]$$

$$W_{hh}h_0 = [0, 0]$$

$$h_1 = \tanh([0.5, 0.2]) \approx [0.462, 0.197]$$

Timestep 2: $h_2=\tanh(W_{xh}x_2+W_{hh}h_1)$

$$W_{xh}x_2 = [0.3, -0.4]$$

$$W_{hh}h_1 = [0.1(0.462)+0.6(0.197),\ -0.3(0.462)+0.2(0.197)] = [0.164, -0.099]$$

$$h_2 = \tanh([0.464, -0.499]) \approx [0.434, -0.461]$$

Timestep 3: $h_3=\tanh(W_{xh}x_3+W_{hh}h_2)$

$$W_{xh}x_3 = [0.8, -0.2]$$

$$W_{hh}h_2 = [0.1(0.434)+0.6(-0.461),\ -0.3(0.434)+0.2(-0.461)] = [-0.233, -0.222]$$

$$h_3 = \tanh([0.567, -0.422]) \approx [0.513, -0.399]$$

Result: The hidden state evolves: $[0,0] \to [0.462,0.197] \to [0.434,-0.461] \to [0.513,-0.399]$. Each step blends the new input with memory from previous steps via $W_{hh}$. The final hidden state $h_3$ encodes a summary of the entire sequence.
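
The three timesteps above can be reproduced with a few lines of plain Python. This is a minimal sketch, not the page's PyTorch implementation: the helper names `matvec`, `rnn_step`, and `run_rnn` are made up for illustration, and the two negative weights ($-0.4$ in $W_{xh}$ and $-0.3$ in $W_{hh}$) are an assumption inferred from the arithmetic of the worked example.

```python
import math

# Weights from the worked example. The two negative entries are an
# assumption inferred from the example's arithmetic.
W_xh = [[0.5, 0.3], [0.2, -0.4]]
W_hh = [[0.1, 0.6], [-0.3, 0.2]]
b_h = [0.0, 0.0]


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rnn_step(x, h_prev):
    """One vanilla RNN update: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)."""
    from_input = matvec(W_xh, x)
    from_memory = matvec(W_hh, h_prev)
    return [math.tanh(a + r + b) for a, r, b in zip(from_input, from_memory, b_h)]


def run_rnn(inputs, h0):
    """Return the list of hidden states [h_0, h_1, ..., h_T]."""
    states = [h0]
    for x in inputs:
        states.append(rnn_step(x, states[-1]))
    return states


states = run_rnn([[1, 0], [0, 1], [1, 1]], [0.0, 0.0])
for t, h in enumerate(states):
    print(f"h_{t} =", [round(v, 3) for v in h])
```

The printed states should agree with the hand calculation to within about 0.001, since the worked example rounds intermediate values to three decimals.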

From-Scratch RNN

python
import torch
import torch.nn as nn

class RNNFromScratch(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.b_h = nn.Parameter(torch.zeros(hidden_size))
        self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size) * 0.01)
        self.b_y = nn.Parameter(torch.zeros(output_size))

    def forward(self, x, h_prev=None):
        """
        x: (batch, seq_len, input_size)
        Returns: outputs (batch, seq_len, output_size), h_final
        """
        batch_size, seq_len, _ = x.shape
        if h_prev is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)

        outputs = []
        h = h_prev
        for t in range(seq_len):
            h = torch.tanh(x[:, t] @ self.W_xh + h @ self.W_hh + self.b_h)
            y = h @ self.W_hy + self.b_y
            outputs.append(y.unsqueeze(1))

        return torch.cat(outputs, dim=1), h

The Vanishing Gradient Problem

Mathematical Analysis

During backpropagation through time (BPTT), the gradient of the loss at time T with respect to the hidden state at time t involves a product of Jacobians:

$$\frac{\partial h_T}{\partial h_t} = \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}$$

Each factor is:

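$$\frac{\partial h_k}{\partial h_{k-1}} = \mathrm{diag}\left(\tanh'(z_k)\right) W_{hh}$$

where $z_k$ is the pre-activation at step $k$ and $\tanh'(z) = 1 - \tanh^2(z) \in (0, 1]$.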

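Because $\tanh'(z) \le 1$, the gradient magnitude is bounded by powers of the spectral norm of $W_{hh}$:

  • If the largest singular value $\sigma_{\max}(W_{hh}) < 1$: gradients vanish exponentially as $T-t$ grows
  • If $\sigma_{\max}(W_{hh}) > 1$: gradients can explode exponentially

$$\left\lVert \frac{\partial h_T}{\partial h_t} \right\rVert \le \left(\sigma_{\max}(W_{hh})\right)^{T-t}$$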

For a 100-step sequence with $\sigma_{\max} = 0.9$:

$$0.9^{100} \approx 2.66 \times 10^{-5}$$

The gradient from the end of the sequence barely reaches the beginning. The network cannot learn long-range dependencies.
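
The decay can be checked numerically on the smallest possible case. The sketch below assumes a scalar RNN, $h_k = \tanh(w_{hh} h_{k-1} + x_k)$, so each Jacobian is the single number $\tanh'(z_k)\, w_{hh}$. The function name `gradient_through_time` and the constant input of 0.1 are illustrative choices, not part of the original page.

```python
import math


def gradient_through_time(w_hh, inputs, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_k = tanh(w_hh * h_{k-1} + x_k)."""
    h, grad = h0, 1.0
    for x in inputs:
        h = math.tanh(w_hh * h + x)
        # One Jacobian factor: tanh'(z_k) * w_hh, with tanh'(z_k) = 1 - h_k^2
        grad *= (1.0 - h * h) * w_hh
    return grad


steps = 100
inputs = [0.1] * steps          # arbitrary small constant input (assumption)
bound = 0.9 ** steps            # the (sigma_max)^(T-t) upper bound
grad = gradient_through_time(0.9, inputs)

print(f"upper bound 0.9^{steps}: {bound:.3e}")
print(f"actual |dh_T/dh_0|:     {abs(grad):.3e}")
```

The actual gradient is smaller than the bound because the $\tanh'$ factors are below 1 whenever the hidden state is non-zero, so the bound is the optimistic case.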

Gradient Clipping (Partial Fix)

Gradient clipping handles exploding gradients but not vanishing ones:

python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
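
To see why clipping only addresses one side of the problem, here is a plain-Python sketch of global-norm clipping. It illustrates the idea behind the call above rather than reproducing PyTorch's implementation; the helper name `clip_by_global_norm` is hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradient values so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scales up
    return [g * scale for g in grads], total_norm


# Exploding gradient: norm 50 is scaled down to about 5.
exploding = [30.0, -40.0]
clipped, norm_before = clip_by_global_norm(exploding, max_norm=5.0)

# Vanishing gradient: already below the threshold, so nothing changes.
vanishing = [3e-6, -4e-6]
unchanged, _ = clip_by_global_norm(vanishing, max_norm=5.0)

print(norm_before, clipped)
print(unchanged)
```

Because the scale factor is capped at 1, small gradients pass through untouched, which is why clipping cannot repair vanishing gradients.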

LSTM: Long Short-Term Memory

LSTMs (Hochreiter and Schmidhuber, 1997) mitigate the vanishing gradient problem by introducing a cell state $c_t$ that carries information across time with additive (not multiplicative) updates, plus gates that control information flow.

LSTM Equations

Forget gate --- what to erase from the cell state:

$$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$$

Input gate --- what new information to write:

$$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$$

Candidate cell state --- what to potentially add:

$$\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$$

Cell state update --- the key innovation (additive, not multiplicative):

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

Output gate --- what to expose to the next layer:

$$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$$

Hidden state:

$$h_t = o_t \odot \tanh(c_t)$$

Why LSTMs Fix Vanishing Gradients

The cell state gradient flows through:

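$$\frac{\partial c_t}{\partial c_{t-1}} = f_t$$

This is the direct path through the cell state, treating the gate values as constants. Since $f_t \in (0,1)$ is a sigmoid output, the gradient can persist if the forget gate stays close to 1. The network learns when to remember (high $f_t$) and when to forget (low $f_t$). This is an additive path --- no repeated multiplication by weight matrices.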
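
Along the direct cell-state path, the gradient over many steps is just the product of the forget-gate values. The sketch below compares two constant gate settings over 100 steps; the function name `cell_path_gradient` and the constant-gate assumption are illustrative simplifications (in a trained LSTM the gates vary per step and per unit).

```python
import math


def cell_path_gradient(forget_gates):
    """Product of forget-gate values: d c_T / d c_t along the direct cell path."""
    return math.prod(forget_gates)


steps = 100
mostly_open = cell_path_gradient([0.99] * steps)   # gate held near 1
leaky = cell_path_gradient([0.90] * steps)         # gate held at 0.9

print(f"forget gate 0.99 for {steps} steps: {mostly_open:.3f}")
print(f"forget gate 0.90 for {steps} steps: {leaky:.2e}")
```

With the gate near 1 a useful fraction of the gradient survives 100 steps, while a gate of 0.9 reproduces the same collapse seen in the vanilla RNN. The difference is that the LSTM can learn where to keep the gate open.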

Worked Example — LSTM Gate Values Step by Step

Setup: Scalar simplification (hidden size 1). $c_0=0$, $h_0=0$, input sequence $x=[1.0, 0.5, -0.5]$

Assume each gate uses one input weight, one recurrent weight, and a bias (simplified):

  • Forget gate: $f_t=\sigma(0.5x_t+0.3h_{t-1}+1.0)$ (bias=1 for remember-by-default)
  • Input gate: $i_t=\sigma(0.4x_t+0.2h_{t-1}+0.0)$
  • Candidate: $\tilde{c}_t=\tanh(0.6x_t+0.1h_{t-1})$
  • Output gate: $o_t=\sigma(0.3x_t+0.4h_{t-1}+0.0)$

Timestep 1 ($x_1=1.0$, $h_0=0$, $c_0=0$):

  • $f_1=\sigma(0.5+0+1.0)=\sigma(1.5)=0.818$ (mostly remember)
  • $i_1=\sigma(0.4+0)=\sigma(0.4)=0.599$ (partially write)
  • $\tilde{c}_1=\tanh(0.6)=0.537$
  • $c_1=0.818\times0+0.599\times0.537=0.322$
  • $o_1=\sigma(0.3)=0.574$
  • $h_1=0.574\times\tanh(0.322)=0.574\times0.311\approx0.178$

Timestep 2 ($x_2=0.5$, $h_1=0.178$, $c_1=0.322$):

  • $f_2=\sigma(0.25+0.053+1.0)=\sigma(1.303)=0.786$
  • $i_2=\sigma(0.2+0.036)=\sigma(0.236)=0.559$
  • $\tilde{c}_2=\tanh(0.3+0.018)=\tanh(0.318)\approx0.307$
  • $c_2=0.786\times0.322+0.559\times0.307=0.253+0.172=0.425$
  • $o_2=\sigma(0.15+0.071)=\sigma(0.221)=0.555$
  • $h_2=0.555\times\tanh(0.425)=0.555\times0.401=0.223$

Timestep 3 ($x_3=-0.5$, $h_2=0.223$, $c_2=0.425$):

  • $f_3=\sigma(-0.25+0.067+1.0)=\sigma(0.817)=0.694$
  • $i_3=\sigma(-0.2+0.045)=\sigma(-0.155)=0.461$
  • $\tilde{c}_3=\tanh(-0.3+0.022)=\tanh(-0.278)=-0.271$
  • $c_3=0.694\times0.425+0.461\times(-0.271)=0.295-0.125=0.170$
  • $o_3=\sigma(-0.15+0.089)=\sigma(-0.061)=0.485$
  • $h_3=0.485\times\tanh(0.170)=0.485\times0.168\approx0.081$

Result: The cell state evolved $0 \to 0.322 \to 0.425 \to 0.170$. The forget gate (0.694--0.818) kept most of the cell state each step (the remember-by-default bias of 1.0 is working). The negative input at $t=3$ made the candidate negative, partially erasing the stored information. The output gate controls how much of the cell state is exposed.
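
The gate values above can be re-derived with a short script. This is a sketch of the scalar simplification only, using the gate weights listed in the setup; the names `lstm_step` and `run_lstm` are illustrative, and the third input is taken to be $-0.5$, which is an assumption based on the example's remark about a negative input at $t=3$.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev):
    """One scalar LSTM step with the simplified weights from the worked example."""
    f = sigmoid(0.5 * x + 0.3 * h_prev + 1.0)   # forget gate (bias 1.0)
    i = sigmoid(0.4 * x + 0.2 * h_prev)         # input gate
    g = math.tanh(0.6 * x + 0.1 * h_prev)       # candidate cell state
    o = sigmoid(0.3 * x + 0.4 * h_prev)         # output gate
    c = f * c_prev + i * g                      # additive cell update
    h = o * math.tanh(c)                        # exposed hidden state
    return h, c, (f, i, g, o)


def run_lstm(xs, h=0.0, c=0.0):
    """Run the scalar LSTM over xs and record (h, c, gates) per step."""
    trace = []
    for x in xs:
        h, c, gates = lstm_step(x, h, c)
        trace.append((h, c, gates))
    return trace


# Third input assumed negative, as described in the example's result.
for t, (h, c, (f, i, g, o)) in enumerate(run_lstm([1.0, 0.5, -0.5]), start=1):
    print(f"t={t}: f={f:.3f} i={i:.3f} g={g:.3f} o={o:.3f} c={c:.3f} h={h:.3f}")
```

Small differences in the third decimal are expected, because the hand calculation rounds each intermediate value before reusing it.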

From-Scratch LSTM

python
class LSTMFromScratch(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Combined weights for efficiency (all 4 gates at once)
        self.W_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size) * 0.01)
        self.W_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size) * 0.01)
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

        # Initialize forget gate bias to 1 (remember by default)
        self.bias.data[hidden_size:2*hidden_size].fill_(1.0)

    def forward(self, x, state=None):
        """
        x: (batch, seq_len, input_size)
        state: tuple of (h, c), each (batch, hidden_size)
        """
        batch_size, seq_len, _ = x.shape
        if state is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = state

        outputs = []
        for t in range(seq_len):
            gates = x[:, t] @ self.W_ih.T + h @ self.W_hh.T + self.bias
            i, f, g, o = gates.chunk(4, dim=1)

            i = torch.sigmoid(i)   # Input gate
            f = torch.sigmoid(f)   # Forget gate
            g = torch.tanh(g)      # Candidate
            o = torch.sigmoid(o)   # Output gate

            c = f * c + i * g      # Cell state update
            h = o * torch.tanh(c)  # Hidden state

            outputs.append(h.unsqueeze(1))

        return torch.cat(outputs, dim=1), (h, c)

Forget Gate Bias

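Initialize the forget gate bias to 1. This makes the LSTM remember by default at the start of training, which helps gradients flow through the cell state before the gates have learned anything. The idea goes back to Gers et al. (2000), and Jozefowicz et al. (2015) found empirically that it improves results and recommended it as a default. It is now common practice.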
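
A rough way to quantify "remember by default": at initialization the weights are near zero, so the forget gate is approximately $\sigma(b_f)$. The sketch below compares how much of the cell state survives under a bias of 0 versus 1. The helper `initial_retention` is hypothetical, and the calculation deliberately ignores the input and recurrent terms.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def initial_retention(forget_bias, steps):
    """Fraction of the cell state left after `steps` updates when the
    forget gate sits at sigmoid(forget_bias) (weights assumed near zero)."""
    return sigmoid(forget_bias) ** steps


for bias in (0.0, 1.0):
    gate = sigmoid(bias)
    kept = initial_retention(bias, steps=10)
    print(f"bias={bias:.0f}: forget gate={gate:.3f}, kept after 10 steps={kept:.4f}")
```

With a zero bias the gate starts at 0.5 and the cell state halves at every step; a bias of 1 starts the gate near 0.73, which keeps far more signal over the same horizon.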

GRU: Gated Recurrent Unit

GRU (Cho et al., 2014) simplifies LSTM by merging the cell and hidden state and using only two gates.

GRU Equations

Reset gate:

$$r_t = \sigma(W_r [h_{t-1}, x_t] + b_r)$$

Update gate:

$$z_t = \sigma(W_z [h_{t-1}, x_t] + b_z)$$

Candidate hidden state:

$$\tilde{h}_t = \tanh(W_h [r_t \odot h_{t-1}, x_t] + b_h)$$

Hidden state update:

$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

Worked Example — GRU Gate Values

Setup: Scalar GRU, $h_0=0.5$, input $x_1=1.0$

Simplified weights: $r_t=\sigma(0.5x_t-0.3h_{t-1})$, $z_t=\sigma(0.4x_t+0.2h_{t-1})$, $\tilde{h}_t=\tanh(0.6x_t+0.3(r_t h_{t-1}))$

Step 1 --- Reset gate: $r_1=\sigma(0.5(1)-0.3(0.5))=\sigma(0.35)=0.587$

Step 2 --- Update gate: $z_1=\sigma(0.4(1)+0.2(0.5))=\sigma(0.5)=0.622$

Step 3 --- Candidate: $\tilde{h}_1=\tanh(0.6(1)+0.3(0.587\times0.5))=\tanh(0.688)=0.597$

Step 4 --- Hidden state: $h_1=(1-0.622)(0.5)+0.622(0.597)=0.189+0.371=0.560$

Result: The update gate $z=0.622$ blends 62% new information with 38% old state. The reset gate $r=0.587$ partially "resets" the old hidden state before computing the candidate, allowing the GRU to selectively forget. The hidden state moved from 0.5 to 0.560, incorporating the new input.
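
The four steps can be checked with a scalar GRU step in plain Python. This is a sketch under the simplified weights of the example, with the reset gate taken as $\sigma(0.5x_t - 0.3h_{t-1})$ (the minus sign is an assumption inferred from $\sigma(0.35)$ in step 1); `gru_step` is an illustrative name.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(x, h_prev):
    """One scalar GRU step with the simplified weights from the worked example."""
    r = sigmoid(0.5 * x - 0.3 * h_prev)              # reset gate
    z = sigmoid(0.4 * x + 0.2 * h_prev)              # update gate
    h_cand = math.tanh(0.6 * x + 0.3 * (r * h_prev)) # candidate from the reset state
    h = (1.0 - z) * h_prev + z * h_cand              # blend old state and candidate
    return h, (r, z, h_cand)


h1, (r1, z1, cand1) = gru_step(x=1.0, h_prev=0.5)
print(f"r={r1:.3f} z={z1:.3f} candidate={cand1:.3f} h1={h1:.3f}")
```

Setting `z` to 0 in this update would copy the old state unchanged, and setting it to 1 would replace it entirely with the candidate, which is the sense in which the update gate interpolates between remembering and overwriting.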

LSTM vs GRU

| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (reset, update) |
| States | 2 (hidden + cell) | 1 (hidden) |
| Parameters | More (4 weight blocks per layer) | Fewer (3 weight blocks per layer, about 25% fewer) |
| Performance | Often slightly better on long sequences | Comparable, faster to train |
| When to use | Default for sequences | When speed matters |
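
The parameter gap follows directly from the equations: each gate or candidate owns one weight matrix in $\mathbb{R}^{n \times (n+d)}$ and one bias in $\mathbb{R}^n$, and the LSTM has four such blocks against the GRU's three. The sketch below counts parameters for a single layer under that formulation; `gated_rnn_params` is a hypothetical helper, and the sizes $d=300$, $n=256$ are just example values.

```python
def gated_rnn_params(input_size, hidden_size, num_blocks):
    """Parameter count for one layer with `num_blocks` gate/candidate blocks,
    each holding a weight matrix of shape (n, n + d) and a bias of length n."""
    per_block = hidden_size * (hidden_size + input_size) + hidden_size
    return num_blocks * per_block


d, n = 300, 256
lstm_params = gated_rnn_params(d, n, num_blocks=4)  # forget, input, candidate, output
gru_params = gated_rnn_params(d, n, num_blocks=3)   # reset, update, candidate

print(f"LSTM: {lstm_params:,} parameters")
print(f"GRU:  {gru_params:,} parameters ({gru_params / lstm_params:.0%} of LSTM)")
```

PyTorch's `nn.LSTM` and `nn.GRU` keep two bias vectors per layer rather than one, so their reported counts are slightly higher, but the 4-to-3 ratio is unchanged.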

Bidirectional RNNs

Process the sequence in both directions, capturing both past and future context:

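$$\overrightarrow{h}_t = \mathrm{RNN}(\overrightarrow{h}_{t-1}, x_t) \quad \text{(forward)}$$

$$\overleftarrow{h}_t = \mathrm{RNN}(\overleftarrow{h}_{t+1}, x_t) \quad \text{(backward)}$$

$$h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t] \quad \text{(concatenation)}$$

python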
# PyTorch bidirectional LSTM
lstm = nn.LSTM(
    input_size=300,
    hidden_size=256,
    num_layers=2,
    batch_first=True,
    bidirectional=True,
    dropout=0.3,
)
# Output shape: (batch, seq_len, 2 * hidden_size)
# h_n shape: (2 * num_layers, batch, hidden_size)
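
Independently of PyTorch, the forward, backward, and concatenate recipe can be sketched with a scalar tanh RNN. The functions `run_rnn` and `bidirectional` and all weight values below are illustrative assumptions, chosen only to show how the backward pass is realigned to sequence positions.

```python
import math


def run_rnn(xs, w_x, w_h):
    """Scalar tanh RNN; returns the hidden state after each input."""
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states


def bidirectional(xs, fwd=(0.5, 0.8), bwd=(0.4, 0.7)):
    """Pair each position's forward state with its backward state.

    The two directions use separate (illustrative) weights.
    """
    forward = run_rnn(xs, *fwd)
    # Run right-to-left, then reverse so index t lines up with position t.
    backward = run_rnn(xs[::-1], *bwd)[::-1]
    return list(zip(forward, backward))


xs = [1.0, 0.0, 0.0, -1.0]
states = bidirectional(xs)
for t, (f, b) in enumerate(states):
    print(f"t={t}: forward={f:+.3f} backward={b:+.3f}")
```

At position 0 the forward state has seen only the first input, while the backward state has already seen the whole sequence. That is the "future context" the concatenated state carries.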

Application 1: IMDB Sentiment Analysis

python
import torch
import torch.nn as nn
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# ── Data Preparation ─────────────────────────────────────────────────
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)

train_iter = IMDB(split='train')
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab['<unk>'])

def text_pipeline(text):
    return vocab(tokenizer(text))

def collate_batch(batch):
    labels, texts = [], []
    for label, text in batch:
        # Older torchtext releases yield 'pos'/'neg'; newer ones yield 2 (pos) / 1 (neg)
        labels.append(1 if label in ('pos', 2) else 0)
        tokens = text_pipeline(text)[:512]  # Truncate to 512 tokens
        texts.append(torch.tensor(tokens, dtype=torch.long))
    labels = torch.tensor(labels, dtype=torch.float)
    # Right-pad every review in the batch to the length of the longest one
    texts = pad_sequence(texts, batch_first=True, padding_value=vocab['<pad>'])
    return texts, labels

train_loader = DataLoader(
    list(IMDB(split='train')), batch_size=64,
    shuffle=True, collate_fn=collate_batch
)
test_loader = DataLoader(
    list(IMDB(split='test')), batch_size=64,
    shuffle=False, collate_fn=collate_batch
)

# ── Model ────────────────────────────────────────────────────────────
class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab['<pad>'])
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        output, (h_n, _) = self.lstm(embedded)
        # Concatenate final hidden states from both directions
        hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(self.dropout(hidden)).squeeze(1)

# ── Training ─────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentimentLSTM(len(vocab), 128, 256, 2).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    model.train()
    total_loss = 0
    for texts, labels in train_loader:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(texts)
        loss = criterion(output, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(device), labels.to(device)
            preds = (torch.sigmoid(model(texts)) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
          f"Acc={100*correct/total:.2f}%")
# Expected: ~87-89% accuracy

Application 2: Stock Price Prediction

python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ── Data ─────────────────────────────────────────────────────────────
def create_sequences(data, seq_length):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return np.array(X), np.array(y)

# Simulate stock data (replace with real data from yfinance)
np.random.seed(42)
days = 1000
prices = np.cumsum(np.random.randn(days) * 2 + 0.05) + 100
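# Normalize with training-period statistics only, so no test information leaks in
n_train = int(0.8 * days)
prices = (prices - prices[:n_train].mean()) / prices[:n_train].std()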

seq_length = 30
X, y = create_sequences(prices, seq_length)

# Split
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]

class StockDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X).unsqueeze(-1)  # (N, seq, 1)
        self.y = torch.FloatTensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_loader = DataLoader(StockDataset(X_train, y_train), batch_size=32, shuffle=True)
test_loader = DataLoader(StockDataset(X_test, y_test), batch_size=32)

# ── Model ────────────────────────────────────────────────────────────
class StockLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :]).squeeze(1)  # Use last time step

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = StockLSTM().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(50):
    model.train()
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X_batch), y_batch)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        model.eval()
        with torch.no_grad():
            test_loss = sum(
                criterion(model(x.to(device)), y.to(device)).item()
                for x, y in test_loader
            ) / len(test_loader)
        print(f"Epoch {epoch+1}: Test MSE={test_loss:.4f}")
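
The sliding-window construction used by `create_sequences` is easiest to see on a toy series. The sketch below is a plain-Python equivalent working on lists instead of NumPy arrays; `create_windows` is an illustrative name, not a function from the code above.

```python
def create_windows(series, seq_length):
    """Pair each run of `seq_length` consecutive values with the value that follows it."""
    X, y = [], []
    for i in range(len(series) - seq_length):
        X.append(series[i:i + seq_length])   # model input: a window of past values
        y.append(series[i + seq_length])     # target: the next value after the window
    return X, y


X, y = create_windows([10, 11, 12, 13, 14, 15], seq_length=3)
for window, target in zip(X, y):
    print(window, "->", target)
```

A series of length $N$ yields $N - \text{seq\_length}$ training pairs, and consecutive windows overlap, which is why the train/test split above is done chronologically rather than by shuffling.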

Stock Prediction Caveat

LSTM stock prediction is a learning exercise. Real markets are largely efficient, and simple LSTM models generally do not beat buy-and-hold strategies on real data. The same technique is useful for other time series tasks (weather, energy, sensor data).

When to Use RNN/LSTM vs Transformers

| Criterion | RNN/LSTM | Transformer |
|---|---|---|
| Sequence length | Short-medium (<500) | Long, up to the model's context window |
| Training speed | Slower (sequential) | Faster (parallelizable) |
| Long-range dependencies | Moderate (LSTM helps) | Excellent (direct attention) |
| Memory efficiency | O(1) state per step | O(n^2) attention matrix |
| Streaming/online | Natural (process one token at a time) | Requires full context |
| 2026 recommendation | Niche use cases | Default for most tasks |

"What I cannot create, I do not understand." — Richard Feynman