Transformers
The transformer architecture (Vaswani et al., 2017) replaced recurrence with self-attention, enabling parallel training and giving every pair of tokens a direct connection, so long-range dependencies no longer have to survive a long chain of recurrent steps. Most major language models (GPT, BERT, T5, LLaMA), many vision models (ViT, DINO), and multimodal models (CLIP, and the text encoder and attention layers of Stable Diffusion) build on transformer components. This page derives every component from first principles, implements a transformer from scratch, and explains why the architecture scales so effectively.
Why Self-Attention?
RNNs process sequences one token at a time --- to connect token 1 to token 100, information must flow through 99 hidden states. Transformers connect every token to every other token directly through attention.
| Property | RNN | Transformer |
|---|---|---|
| Parallelization | Sequential (slow) | Fully parallel (fast) |
| Long-range dependencies | Signal degrades over distance | Direct O(1) connection |
| Maximum path length | O(n) | O(1) |
| Computation per layer | O(n·d^2), d = hidden size | O(n^2·d) (attention matrix) |
| Training speed | Slow | Fast on GPUs |
Scaled Dot-Product Attention
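The Core Equation
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
where:
- $Q$ are queries ("what am I looking for?")
- $K$ are keys ("what do I contain?")
- $V$ are values ("what information do I provide?")
- $d_k$ is the dimension of keys/queries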
Step-by-Step Derivation
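Step 1 --- Compute attention scores:
$$S = QK^\top, \qquad S_{ij} = q_i \cdot k_j$$
Step 2 --- Scale:
$$\tilde{S} = \frac{S}{\sqrt{d_k}}$$
Without scaling, when $d_k$ is large the dot products grow large in magnitude: if the components of $q$ and $k$ are independent with zero mean and unit variance, $q \cdot k$ has variance $d_k$. Large scores push the softmax into saturated regions where gradients are very small.
Dividing by $\sqrt{d_k}$ brings the variance of the scores back to about 1.
Step 3 --- Softmax:
$$A = \text{softmax}(\tilde{S}), \qquad A_{ij} = \frac{e^{\tilde{S}_{ij}}}{\sum_{j'} e^{\tilde{S}_{ij'}}}$$
Each row of the attention weights sums to 1.
Step 4 --- Weighted sum:
$$\text{Output} = AV$$
Each output is a weighted combination of value vectors, where the weights come from the attention scores.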
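The variance argument behind the scaling in Step 2 can be checked numerically. The following is a minimal sketch in plain Python (no PyTorch required); the function name `dot_product_variance`, the sample count, and the seed are illustrative choices rather than part of the derivation.

```python
import math
import random


def dot_product_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Estimate Var(q . k) for q, k with i.i.d. standard normal components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(qi * ki for qi, ki in zip(q, k))
        if scale:
            dot /= math.sqrt(d_k)  # the 1/sqrt(d_k) factor from Step 2
        dots.append(dot)
    mean = sum(dots) / n_samples
    return sum((x - mean) ** 2 for x in dots) / n_samples


for d_k in (4, 64):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:3d}  raw variance ~ {raw:6.2f}  scaled variance ~ {scaled:4.2f}")
```

The raw variance should track $d_k$, while the scaled variance should stay close to 1 regardless of dimension.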
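Worked Example: Self-Attention QKV on a 3-Word Sentence
Input: 3 tokens with $d_k = 2$, so $Q$, $K$, and $V$ are each $3 \times 2$ matrices.
Think of rows as: word "The" (row 0), "cat" (row 1), "sat" (row 2).
Step 1 --- Attention scores $S = QK^\top$ (a $3 \times 3$ matrix). The raw scores in this example are small integers: "cat" scores $(1, 1, 0)$ and "sat" scores $(2, 1, 1)$ against ("The", "cat", "sat"), and row 0 contains two 1s and one 0.
Step 2 --- Scale by $\sqrt{d_k} = \sqrt{2} \approx 1.414$: a score of 0 stays 0, a score of 1 becomes 0.707, and a score of 2 becomes 1.414.
Step 3 --- Softmax (row-wise), using $e^{0} = 1.000$, $e^{0.707} \approx 2.028$, $e^{1.414} \approx 4.113$:
- Row 0: two terms of $2.028$ and one of $1.000$. Sum = 5.056, giving weights of 0.401, 0.401, and 0.198 on the corresponding tokens.
- Row 1: $2.028$, $2.028$, $1.000$. Sum = 5.056. Weights: $(0.401, 0.401, 0.198)$.
- Row 2: $4.113$, $2.028$, $2.028$. Sum = 8.169. Weights: $(0.503, 0.248, 0.248)$.
Step 4 --- Weighted sum $AV$: each output row blends the value vectors $v_0, v_1, v_2$ (the rows of $V$) using that row's weights.
Output for "The" (row 0): $a_{00} v_0 + a_{01} v_1 + a_{02} v_2$
Output for "cat" (row 1): $0.401\, v_0 + 0.401\, v_1 + 0.198\, v_2$
Output for "sat" (row 2): $0.503\, v_0 + 0.248\, v_1 + 0.248\, v_2$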
Result: "sat" (row 2) attends most strongly to "The" (50.3%) because their Q and K vectors have the highest dot product (2.0). "cat" attends equally to "The" and "cat" (40.1% each). Each output is a blend of all value vectors, weighted by relevance.
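The arithmetic in this example can be reproduced in a few lines of plain Python. The matrices in this sketch are assumptions: `Q` and `K` are chosen so that the raw scores match the numbers quoted in the result (a dot product of 2.0 from "sat" to "The", equal scores from "cat" to "The" and "cat"), and `V` is a purely illustrative choice.

```python
import math

# Hypothetical inputs; rows correspond to "The", "cat", "sat"
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Scaled dot-product attention on nested lists; returns (outputs, weights)."""
    d_k = len(Q[0])
    weights, outputs = [], []
    for q in Q:
        # Steps 1 and 2: dot product with every key, divided by sqrt(d_k)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Step 3: row-wise softmax
        w = softmax(scores)
        weights.append(w)
        # Step 4: weighted sum of value vectors
        outputs.append([sum(w[j] * V[j][c] for j in range(len(V)))
                        for c in range(len(V[0]))])
    return outputs, weights


outputs, weights = attention(Q, K, V)
for name, w in zip(("The", "cat", "sat"), weights):
    print(name, [round(x, 3) for x in w])
```

With these inputs, the "sat" row should place roughly half of its weight on "The", and the "cat" row should split most of its weight evenly between "The" and "cat".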
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_q, d_k)
    K: (batch, heads, seq_k, d_k)
    V: (batch, heads, seq_k, d_v)
    mask: (batch, 1, 1, seq_k) or (batch, 1, seq_q, seq_k)
    """
    d_k = Q.size(-1)
    # Steps 1 and 2: raw scores, scaled by sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 receive zero weight after the softmax
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Step 3: softmax over the key dimension
    attn_weights = F.softmax(scores, dim=-1)
    # Step 4: weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

Multi-Head Attention
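Instead of one attention function, run $h$ attention heads in parallel, each on its own learned projection of the inputs, and concatenate the results:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$
where each head is:
$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
with $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d_{model}}$.
Typically $d_k = d_v = d_{model} / h$, so the total computation is similar to that of a single head with full dimensionality.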
Why multiple heads? Different heads can learn different types of relationships: syntactic (subject-verb), semantic (co-reference), positional (adjacent tokens), etc.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear projections and reshape to (batch, heads, seq, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_k
        )
        return self.W_o(attn_output)

Positional Encoding
Since attention is permutation-equivariant (no notion of order), we must inject position information.
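Permutation equivariance can be demonstrated directly: shuffling the input tokens simply shuffles the outputs in the same way. The sketch below is a plain-Python illustration that uses identity projections ($Q = K = V = X$) as a simplifying assumption; the helper name `self_attention` and the toy inputs are illustrative.

```python
import math


def self_attention(X):
    """Single-head self-attention with Q = K = V = X (no learned projections)."""
    d = len(X[0])
    out = []
    for q in X:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([sum(e / total * v[c] for e, v in zip(exps, X))
                    for c in range(d)])
    return out


X = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
perm = [2, 0, 1]  # a reordering of the three tokens

out_original = self_attention(X)
out_permuted = self_attention([X[i] for i in perm])

# Permuting the input rows permutes the output rows in exactly the same way,
# so the layer by itself cannot tell which token came first.
expected = [out_original[i] for i in perm]
print(all(abs(a - b) < 1e-9
          for row_p, row_e in zip(out_permuted, expected)
          for a, b in zip(row_p, row_e)))
```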
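Sinusoidal Positional Encoding
$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
where $pos$ is the token position and $i$ indexes the sine/cosine dimension pairs.
Why sinusoidal? For any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$ (a rotation of each sine/cosine pair), which makes it easy for the model to attend by relative position.
Worked Example: Sinusoidal Positional Encoding
Setup: $d_{model} = 4$, positions $pos = 0, 1, 2$.
Dimension indices: $i = 0$ (dimensions 0 and 1) and $i = 1$ (dimensions 2 and 3).
Denominators:
- $i = 0$: $10000^{0/4} = 1$
- $i = 1$: $10000^{2/4} = 100$
Position 0: $[\sin 0, \cos 0, \sin 0, \cos 0] = [0, 1, 0, 1]$
Position 1: $[\sin 1, \cos 1, \sin 0.01, \cos 0.01] \approx [0.8415, 0.5403, 0.0100, 1.0000]$
Position 2: $[\sin 2, \cos 2, \sin 0.02, \cos 0.02] \approx [0.9093, -0.4161, 0.0200, 0.9998]$
Result: Low-index dimensions ($i = 0$) oscillate quickly and distinguish neighbouring positions, while high-index dimensions ($i = 1$) change slowly and encode coarse position.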
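The sinusoidal encoding, and the fixed-offset property that motivates it, can be checked with a short plain-Python sketch. It assumes an even model dimension and uses $d_{model} = 4$ for the demonstration; the helper names `sinusoidal_encoding` and `shift_by` are illustrative.

```python
import math


def sinusoidal_encoding(pos, d_model):
    """Return the d_model-dimensional sinusoidal encoding for one position."""
    pe = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe


def shift_by(pe, k, d_model):
    """Rotate each (sin, cos) pair to obtain the encoding at position pos + k."""
    shifted = []
    for i in range(d_model // 2):
        theta = k / (10000 ** (2 * i / d_model))
        s, c = pe[2 * i], pe[2 * i + 1]
        # sin(a + t) = sin a cos t + cos a sin t
        shifted.append(s * math.cos(theta) + c * math.sin(theta))
        # cos(a + t) = cos a cos t - sin a sin t
        shifted.append(c * math.cos(theta) - s * math.sin(theta))
    return shifted


for pos in range(3):
    print(pos, [round(v, 4) for v in sinusoidal_encoding(pos, 4)])
```

Because `shift_by` depends only on the offset `k` and not on the starting position, applying it to the encoding of position 1 with `k = 2` should reproduce the encoding of position 3.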
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer: saved and moved with the module, but not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

Learned Positional Embeddings
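Some models (for example GPT-2 and BERT) use learned positional embeddings instead:
# Inside the model's __init__: one trainable vector per position
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
# Usage in forward, with x of shape (batch, seq_len, d_model):
# x = x + self.pos_embedding(torch.arange(seq_len, device=x.device))

Rotary Position Embedding (RoPE)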
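Used in LLaMA, Mistral, and many other recent LLMs. Encodes position by rotating the query and key vectors by an angle proportional to their position:
$$\tilde{q}_m = R_m q_m, \qquad \tilde{k}_n = R_n k_n$$
where $R_m$ is block-diagonal and rotates each pair of dimensions $(2i, 2i+1)$ by the angle $m\theta_i$, with $\theta_i = 10000^{-2i/d_k}$.
This directly encodes relative position in the dot product:
$$\tilde{q}_m^\top \tilde{k}_n = q_m^\top R_m^\top R_n k_n = q_m^\top R_{n-m} k_n$$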
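The relative-position property can be verified numerically. This is a minimal plain-Python sketch, not a production implementation: it rotates adjacent dimension pairs and assumes the common base of 10000 (real implementations differ in how they pair dimensions). The function name `rope` and the toy vectors are illustrative.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotate each consecutive pair of dimensions of vec by pos * theta_i."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.1, 0.4, -0.7, 0.9]

# Same relative offset (n - m = 3) at two different absolute positions
score_a = dot(rope(q, 2), rope(k, 5))
score_b = dot(rope(q, 40), rope(k, 43))
# A different relative offset (n - m = 4)
score_c = dot(rope(q, 2), rope(k, 6))
print(abs(score_a - score_b) < 1e-9, abs(score_a - score_c) > 1e-6)
```

The score depends only on the offset between the two positions, so `score_a` and `score_b` should agree up to floating-point error while `score_c` differs.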
Feed-Forward Network
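Each transformer layer contains a position-wise feed-forward network (applied identically to each position):
$$\text{FFN}(x) = W_2\, \phi(W_1 x + b_1) + b_2$$
where $\phi$ is a nonlinearity (ReLU in the original paper, GELU in many later models).
The inner dimension is typically $d_{ff} = 4 \times d_{model}$ (for example 2048 for $d_{model} = 512$).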
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        # Expand to d_ff, apply the nonlinearity, project back to d_model
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Feed-forward with residual connection and layer norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

Decoder Layer
The decoder adds masked self-attention (causal mask) and cross-attention to the encoder output:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Cross-attention (queries from decoder, keys/values from encoder)
        cross_out = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_out))
        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_out))
        return x

Causal Mask
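For autoregressive generation, position $i$ may attend only to positions $j \le i$, so that a prediction can never depend on future tokens.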
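The effect of the causal constraint on the attention weights can be seen in a small plain-Python sketch. The function name `causal_attention_weights` and the score values are illustrative; scaling by $\sqrt{d_k}$ is omitted because the scores are given directly.

```python
import math


def causal_attention_weights(scores):
    """Apply a causal mask to a square score matrix, then softmax each row."""
    n = len(scores)
    weights = []
    for i in range(n):
        # Position i may look at positions 0..i; later positions get -inf
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(n)]
        m = max(masked)
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[0.5, 2.0, 1.0],
          [1.0, 1.0, 3.0],
          [0.2, 0.4, 0.6]]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

The first token can only attend to itself, and the large score of 3.0 from position 1 to position 2 is ignored because it points to the future.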
def create_causal_mask(seq_len):
    """Lower-triangular mask: 1 = allowed, 0 = masked (future positions)."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)

Full Transformer
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8,
                 n_enc_layers=6, n_dec_layers=6, d_ff=2048,
                 max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        # Embeddings
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        # Encoder and decoder stacks
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_enc_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_dec_layers)
        ])
        # Output projection
        self.output_proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        # Scale embeddings by sqrt(d_model) before adding positions
        x = self.pos_encoding(self.src_embed(src) * math.sqrt(self.d_model))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        x = self.pos_encoding(self.tgt_embed(tgt) * math.sqrt(self.d_model))
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        return self.output_proj(dec_output)

"Attention Is All You Need" --- Paper Walkthrough
The original transformer paper introduced:
- No recurrence, no convolution --- purely attention-based
- Encoder-decoder structure with 6 layers each
- Multi-head self-attention with 8 heads, $d_{model} = 512$, $d_k = d_v = 64$
- Positional encoding via sinusoidal functions
- Residual connections + LayerNorm around each sub-layer
- Label smoothing ($\epsilon_{ls} = 0.1$) for training
- Warmup + inverse square root decay learning rate schedule
Training details:
- WMT 2014 English-German: 4.5M sentence pairs
- 8 NVIDIA P100 GPUs, 3.5 days
- BLEU score: 28.4 (new SOTA)
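The learning rate schedule:
$$lr = d_{model}^{-0.5} \cdot \min\left(step^{-0.5},\; step \cdot warmup\_steps^{-1.5}\right)$$
with $warmup\_steps = 4000$ in the paper.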
This increases linearly during warmup, then decays with the inverse square root.
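As a quick sanity check, the schedule can be written as a small Python function. This is a sketch assuming the paper's base settings ($d_{model} = 512$, 4000 warmup steps); the function name `transformer_lr` is illustrative.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate at a given optimizer step (steps are counted from 1)."""
    step = max(step, 1)  # guard against step 0, where step ** -0.5 is undefined
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:6d}: lr = {transformer_lr(step):.6f}")
```

The two arguments of `min` cross exactly at `step == warmup_steps`, which is where the learning rate peaks. In PyTorch, a function like this could be passed to `torch.optim.lr_scheduler.LambdaLR` with the optimizer's base learning rate set to 1.0, keeping in mind that the scheduler counts from 0.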
Why Transformers Scale
Parallelization
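RNNs must process tokens sequentially ($O(n)$ dependent steps per sequence), while a transformer layer handles all $n$ positions at once with large matrix multiplications, which is the workload GPUs and TPUs are built for.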
Compute-Data Scaling Laws
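Kaplan et al. (2020) showed that transformer loss follows power laws:
$$L(N) \propto N^{-\alpha_N}, \qquad L(D) \propto D^{-\alpha_D}, \qquad L(C) \propto C^{-\alpha_C}$$
where $N$ is the number of parameters, $D$ is the dataset size in tokens, $C$ is the training compute, and each law holds when the other two factors are not the bottleneck.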
Attention as a Learned Index
Self-attention lets the model dynamically route information based on content, not fixed connectivity. This is a form of learned conditional computation that scales gracefully with model size.
Attention Complexity and Optimization
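Standard self-attention is $O(n^2)$ in both time and memory for sequence length $n$. Several methods reduce the cost (here $h$ is the number of heads, $g$ the number of KV groups, $w$ the window size, and $p$ the number of devices):
| Method | Complexity | Approach |
|---|---|---|
| Standard | $O(n^2)$ time, $O(n^2)$ memory | Full attention matrix |
| Flash Attention | $O(n^2)$ time, $O(n)$ memory | Tiled computation, no materialization |
| Multi-Query Attention | $O(n^2)$ time, KV cache reduced by $h$ | Shared K, V across heads |
| Grouped-Query Attention | $O(n^2)$ time, KV cache reduced by $h/g$ | Groups of heads share K, V |
| Sliding Window | $O(n \cdot w)$ | Local attention window |
| Ring Attention | $O(n^2 / p)$ per device | Distribute across devices |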
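To make the gap between full and local attention concrete, the sketch below counts how many (query, key) pairs are evaluated under causal attention with and without a sliding window. It is an illustrative calculation only; the function name `attended_pairs` and the window size of 256 are assumptions.

```python
def attended_pairs(n, window=None):
    """Count (query, key) pairs evaluated under causal attention.

    window=None means full causal attention; otherwise each query sees at most
    `window` keys (itself and the window - 1 tokens before it).
    """
    total = 0
    for i in range(n):
        visible = i + 1 if window is None else min(i + 1, window)
        total += visible
    return total


n = 4096
full = attended_pairs(n)                # n (n + 1) / 2, quadratic in n
local = attended_pairs(n, window=256)   # roughly n * window, linear in n
print(full, local, round(full / local, 1))
```

Doubling `n` roughly quadruples `full` but only doubles `local`, which is the difference between the quadratic and the windowed rows of the table.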
Pre-Norm vs Post-Norm
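The original paper uses post-norm (normalize after the residual):
$$x_{l+1} = \text{LayerNorm}(x_l + \text{Sublayer}(x_l))$$
Modern models use pre-norm (normalize before the sublayer):
$$x_{l+1} = x_l + \text{Sublayer}(\text{LayerNorm}(x_l))$$
Pre-norm is more stable for deep networks and is much less dependent on learning rate warmup, because the residual path carries activations and gradients through without normalization.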
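The two orderings differ only in where the normalization sits relative to the residual addition. The following plain-Python sketch makes that explicit on a single vector; `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network, and the layer norm here has no learned gain or bias.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    """Post-norm ordering: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    """Pre-norm ordering: x + Sublayer(LayerNorm(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def toy_sublayer(x):
    """Hypothetical stand-in for attention or the feed-forward network."""
    return [0.5 * v for v in x]


x = [1.0, 2.0, 3.0, 10.0]
post = post_norm_block(x, toy_sublayer)
pre = pre_norm_block(x, toy_sublayer)
print([round(v, 3) for v in post])
print([round(v, 3) for v in pre])
```

In the pre-norm output the input `x` is still present unchanged, with only a normalized update added on top; in the post-norm output the whole sum has been renormalized, so the identity path is rescaled at every layer.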
KV Cache for Efficient Inference
During autoregressive generation, naive attention recomputes keys and values for all previous tokens at each step. The KV cache stores previously computed keys and values, reducing per-step attention computation from $O(n^2)$ to $O(n)$ for a prefix of length $n$.
class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        """x: (batch, new_tokens, d_model); kv_cache: (K, V) from earlier steps, or None."""
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            # Append to cached keys and values along the sequence axis
            prev_k, prev_v = kv_cache
            k = torch.cat([prev_k, k], dim=2)
            v = torch.cat([prev_v, v], dim=2)
        new_cache = (k, v)
        # Attention with full K, V but only new Q
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Causal mask: new token i sits at absolute position S - T + i and may only
        # see keys up to that position. Only matters when T > 1 (e.g. the prompt).
        S = k.size(2)
        causal = torch.tril(torch.ones(T, S, device=x.device), diagonal=S - T)
        att = att.masked_fill(causal == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(y), new_cache

Memory Savings from KV Cache
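Without cache: generating $n$ tokens recomputes attention over the whole prefix at every step, for $O(n^3)$ total attention computation ($\sum_{t=1}^{n} t^2$). With cache: each step attends from one new query to the $t$ cached keys, for $O(n^2)$ total.
The trade-off is memory: the KV cache grows linearly with sequence length. For a model with $L$ layers, $h_{kv}$ key/value heads of dimension $d_k$, and sequence length $n$, the cache size per sequence is:
$$\text{KV cache bytes} = 2 \times L \times h_{kv} \times d_k \times n \times \text{bytes per element}$$
For LLaMA 2 70B (80 layers, 8 KV heads, head dimension 128) at 4096 tokens with FP16: $2 \times 80 \times 8 \times 128 \times 4096 \times 2$ bytes, approximately 1.3 GB of KV cache per sequence. With a separate K, V per query head (64 KV heads), the same model would need about 10.7 GB.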
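The cache size is easy to compute for any configuration. The sketch below assumes a LLaMA 2 70B-style layout (80 layers, head dimension 128, FP16) and compares 8 KV heads against 64; the function name `kv_cache_bytes` and its parameter names are illustrative.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   bytes_per_elem=2, batch_size=1):
    """Bytes needed to cache keys and values (the leading 2) for every layer."""
    return (2 * n_layers * n_kv_heads * head_dim
            * seq_len * bytes_per_elem * batch_size)


# Assumed configuration: 80 layers, head dimension 128, FP16 (2 bytes per element)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=4096)
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=4096)
print(f"8 KV heads:  {gqa / 1e9:.2f} GB")
print(f"64 KV heads: {mha / 1e9:.2f} GB")
```

The result scales linearly with sequence length and batch size, so serving many long sequences at once multiplies these numbers accordingly.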
Grouped-Query and Multi-Query Attention
Standard multi-head attention uses separate K, V projections per head. This makes the KV cache large.
Multi-Query Attention (MQA): All heads share the same K and V. KV cache shrinks by a factor of $h$, the number of heads.
Grouped-Query Attention (GQA): The heads are split into $g$ groups, and the heads in each group share K, V. A compromise between MHA and MQA.
| Method | K, V per head | KV Cache Size | Quality |
|---|---|---|---|
| MHA | Unique per head | Large ($h$ KV heads) | Best |
| GQA | Shared per group | Medium ($g$ KV heads) | Near-MHA |
| MQA | Single shared | Small (1 KV head) | Slightly worse |
LLaMA 2 70B uses GQA with 8 KV heads (vs 64 query heads).
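The sharing pattern can be described by a simple index mapping from query heads to KV heads. The sketch below assumes that consecutive query heads form a group, which is one common convention; the function name `kv_head_for_query_head` is illustrative.

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    """Index of the K/V head used by a given query head under GQA."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


# 64 query heads sharing 8 KV heads: groups of 8 consecutive query heads
gqa_map = [kv_head_for_query_head(h, 64, 8) for h in range(64)]
# MQA is the special case of a single KV head, MHA of one KV head per query head
mqa_map = [kv_head_for_query_head(h, 64, 1) for h in range(64)]
mha_map = [kv_head_for_query_head(h, 64, 64) for h in range(64)]
print(gqa_map[:10])
```

Only the K and V projections shrink; every query head still has its own query projection and attends independently.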
Flash Attention
Flash Attention (Dao et al., 2022) computes exact attention without materializing the full $n \times n$ attention matrix in GPU high-bandwidth memory (HBM):
- Split Q, K, V into blocks that fit in SRAM
- Compute attention block-by-block using online softmax
- Never write the full attention matrix to GPU HBM
Results: 2-4x speedup and 5-20x memory reduction, with exact (not approximate) computation.
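The online softmax idea behind the block-by-block computation can be illustrated for a single query in plain Python. This is a conceptual sketch only: it shows that processing keys and values in blocks with a running maximum and running sum gives the same result as the full softmax, and it says nothing about the GPU kernel itself. The function names and toy data are illustrative, and the $\sqrt{d_k}$ scaling is omitted for brevity.

```python
import math


def attention_full(q, keys, values):
    """Reference: softmax over all scores at once, then a weighted sum."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[c] for e, v in zip(exps, values)) / total
            for c in range(dim)]


def attention_blockwise(q, keys, values, block_size=2):
    """Same result, visiting keys/values one block at a time (online softmax)."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vc for a, vc in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(attention_full(q, keys, values))
print(attention_blockwise(q, keys, values, block_size=2))
```

Only one block of scores exists at any time, which is why the full row of attention weights never needs to be stored.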
# PyTorch 2.0+ ships a fused scaled_dot_product_attention
import torch.nn.functional as F
# Called through F so it does not shadow the from-scratch function defined earlier.
# Uses a Flash Attention kernel when the device, dtype, and shapes support it
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

Transformer Debugging Tips
| Issue | Symptom | Fix |
|---|---|---|
| Attention to wrong positions | Model ignores relevant tokens | Check causal mask is applied correctly |
| Positional encoding missing | Model treats input as a bag of tokens | Add positional encoding/embedding |
| Post-norm instability | Loss spikes during training | Switch to pre-norm |
| Attention overflow | NaN in attention weights | Use scaled dot-product ($1/\sqrt{d_k}$) |
| Embedding scale | Poor early training | Multiply embeddings by $\sqrt{d_{model}}$ |
Cross-References
- Predecessors: RNN and LSTM --- what transformers replaced
- NLP applications: Language Models --- GPT, BERT, T5
- Vision application: Image Classification --- ViT
- Generative: Text Generation --- decoding strategies
- BERT family: BERT Family --- encoder-only transformers
- Scaling: Model Optimization --- making transformers efficient