Unverified: AI-generated content.

Transformers

The transformer architecture (Vaswani et al., 2017) replaced recurrence with self-attention. This enables parallel training and connects distant tokens directly, so gradients no longer have to survive a long chain of recurrent steps (the source of vanishing gradients in RNNs). Nearly every major language model (GPT, BERT, T5, LLaMA) is a transformer, as are many leading vision models (ViT, DINO), and multimodal models such as CLIP and Stable Diffusion rely on transformer components. This page derives every component from first principles, implements a transformer from scratch, and explains why the architecture scales so effectively.

Why Self-Attention?

RNNs process sequences one token at a time: to connect token 1 to token 100, information must flow through 99 sequential hidden-state updates. Transformers connect every token to every other token directly through attention.

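| Property | RNN | Transformer |
|---|---|---|
| Parallelization | Sequential (slow) | Fully parallel (fast) |
| Long-range dependencies | Signal degrades over distance | Direct connection between any two tokens |
| Maximum path length | $O(n)$ | $O(1)$ |
| Computation per layer | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ (attention matrix) |
| Training speed | Slow | Fast on GPUs |

Here $n$ is the sequence length and $d$ is the model dimension.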

Scaled Dot-Product Attention

The Core Equation

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where:

  • $Q \in \mathbb{R}^{n \times d_k}$ are queries ("what am I looking for?")
  • $K \in \mathbb{R}^{m \times d_k}$ are keys ("what do I contain?")
  • $V \in \mathbb{R}^{m \times d_v}$ are values ("what information do I provide?")
  • $d_k$ is the dimension of keys/queries

Step-by-Step Derivation

Step 1 --- Compute attention scores:

$$S = QK^\top \in \mathbb{R}^{n \times m}$$

$S_{ij}$ measures how much query $i$ should attend to key $j$. This is a dot product: $S_{ij} = q_i \cdot k_j$.

Step 2 --- Scale:

$$S_{\text{scaled}} = \frac{S}{\sqrt{d_k}}$$

Without scaling, dot products grow in magnitude as $d_k$ increases, pushing softmax into saturation (where gradients are near zero). If the components of $q$ and $k$ are independent with zero mean and unit variance, each product $q_i k_i$ has variance 1, so:

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

Dividing by $\sqrt{d_k}$ normalizes the variance back to 1.
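A quick Monte Carlo check of this argument, as a NumPy sketch (the choice of $d_k = 64$, the sample count, and the seed are arbitrary): with i.i.d. standard-normal components, the raw dot products have variance close to $d_k$ and the scaled ones close to 1.

```python
import numpy as np

# Assumption: query/key components are i.i.d. standard normal, as in the
# derivation above. Sample many (q, k) pairs and measure Var(q . k).
rng = np.random.default_rng(0)
d_k, n_samples = 64, 20_000

q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

dots = np.sum(q * k, axis=1)      # one dot product per (q, k) pair
scaled = dots / np.sqrt(d_k)      # the 1/sqrt(d_k) scaling from attention

print(f"Var(q.k)             ~ {dots.var():.1f}  (theory: {d_k})")
print(f"Var(q.k / sqrt(d_k)) ~ {scaled.var():.3f} (theory: 1)")
```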

Step 3 --- Softmax:

$$\alpha_{ij} = \frac{\exp(S_{ij}/\sqrt{d_k})}{\sum_{l=1}^{m} \exp(S_{il}/\sqrt{d_k})}$$

Each row of the attention weights sums to 1: $\alpha_{ij}$ is the fraction of attention that position $i$ pays to position $j$.

Step 4 --- Weighted sum:

$$\text{output}_i = \sum_{j=1}^{m} \alpha_{ij} v_j$$

Each output is a weighted combination of value vectors, where the weights come from the attention scores.

Worked Example — Self-Attention QKV on a 3-Word Sentence

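Input: 3 tokens with $d_k = 2$ (tiny for hand-tracing). The embeddings are already projected to $Q$, $K$, $V$:

$$Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}, \quad K = \begin{bmatrix} 1 & 1 \\ 0 & 1 \\ 1 & 0 \end{bmatrix}, \quad V = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0.5 & 0.5 \end{bmatrix}$$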

Think of rows as: word "The" (row 0), "cat" (row 1), "sat" (row 2).

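Step 1 --- Attention scores $S = QK^\top$:

$$S = \begin{bmatrix} 1(1)+0(1) & 1(0)+0(1) & 1(1)+0(0) \\ 0(1)+1(1) & 0(0)+1(1) & 0(1)+1(0) \\ 1(1)+1(1) & 1(0)+1(1) & 1(1)+1(0) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 1 \\ 1 & 1 & 0 \\ 2 & 1 & 1 \end{bmatrix}$$

Step 2 --- Scale by $\sqrt{d_k} = \sqrt{2} \approx 1.414$:

$$S_{\text{scaled}} = \begin{bmatrix} 0.707 & 0 & 0.707 \\ 0.707 & 0.707 & 0 \\ 1.414 & 0.707 & 0.707 \end{bmatrix}$$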

Step 3 --- Softmax (row-wise):

  • Row 0: $e^{0.707} = 2.028$, $e^{0} = 1$, $e^{0.707} = 2.028$. Sum = 5.056, giving $[0.401, 0.198, 0.401]$
  • Row 1: $e^{0.707} = 2.028$, $e^{0.707} = 2.028$, $e^{0} = 1$. Sum = 5.056, giving $[0.401, 0.401, 0.198]$
  • Row 2: $e^{1.414} = 4.113$, $e^{0.707} = 2.028$, $e^{0.707} = 2.028$. Sum = 8.169, giving $[0.503, 0.248, 0.248]$

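Step 4 --- Weighted sum $\alpha V$ (computed from the unrounded weights, so the last digit can differ from hand arithmetic with the rounded values):

Output for "The" (row 0): $0.401[1, 0] + 0.198[0, 1] + 0.401[0.5, 0.5] = [0.602, 0.398]$

Output for "cat" (row 1): $0.401[1, 0] + 0.401[0, 1] + 0.198[0.5, 0.5] = [0.500, 0.500]$

Output for "sat" (row 2): $0.503[1, 0] + 0.248[0, 1] + 0.248[0.5, 0.5] = [0.628, 0.372]$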

Result: "sat" (row 2) attends most strongly to "The" (50.3%) because their Q and K vectors have the highest dot product (2.0). "cat" attends equally to "The" and "cat" (40.1% each). Each output is a blend of all value vectors, weighted by relevance.
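The hand computation above can be reproduced in a few lines. This NumPy sketch uses the same toy $Q$, $K$, $V$; the printed weights and outputs should match Steps 3 and 4 to three decimals.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Rows: "The", "cat", "sat" (the toy Q, K, V from the worked example)
Q = np.array([[1, 0], [0, 1], [1, 1]], dtype=float)
K = np.array([[1, 1], [0, 1], [1, 0]], dtype=float)
V = np.array([[1, 0], [0, 1], [0.5, 0.5]])

d_k = Q.shape[1]
S = Q @ K.T                            # Step 1: raw scores
weights = softmax(S / np.sqrt(d_k))    # Steps 2-3: scale, then row-wise softmax
output = weights @ V                   # Step 4: weighted sum of values

print(np.round(weights, 3))
print(np.round(output, 3))
```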

Implementation

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_q, d_k)
    K: (batch, heads, seq_k, d_k)
    V: (batch, heads, seq_k, d_v)
    mask: (batch, 1, 1, seq_k) or (batch, 1, seq_q, seq_k)
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

Multi-Head Attention

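Instead of one attention function, run $h$ attention heads in parallel, each with different learned projections:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

with $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$.

Typically $d_k = d_v = d_{\text{model}}/h$. With $d_{\text{model}} = 512$ and $h = 8$, each head has $d_k = 64$.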

Why multiple heads? Different heads can learn different types of relationships: syntactic (subject-verb), semantic (co-reference), positional (adjacent tokens), etc.

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections and reshape to (batch, heads, seq, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_k
        )
        return self.W_o(attn_output)
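The `.view(...).transpose(1, 2)` reshape in `forward` is where head bookkeeping usually goes wrong. A NumPy sketch of the same split and merge (NumPy so it runs without PyTorch; `split_heads` and `merge_heads` are illustrative names, not part of the class above):

```python
import numpy as np

# Shapes follow the text: d_model = 512 split across h = 8 heads -> d_k = 64.
batch, seq_len, d_model, n_heads = 2, 5, 512, 8
d_k = d_model // n_heads

def split_heads(x, n_heads):
    """(batch, seq, d_model) -> (batch, heads, seq, d_k)"""
    b, s, d = x.shape
    return x.reshape(b, s, n_heads, d // n_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    """(batch, heads, seq, d_k) -> (batch, seq, heads * d_k), i.e. Concat(head_1..head_h)"""
    b, h, s, dk = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, s, h * dk)

rng = np.random.default_rng(0)
x = rng.standard_normal((batch, seq_len, d_model))  # stands in for the projected x W^Q
heads = split_heads(x, n_heads)
print(heads.shape)  # (2, 8, 5, 64)

assert np.array_equal(merge_heads(heads), x)        # split/merge is lossless
# Head 3 sees a contiguous slice of the projected features
assert np.array_equal(heads[:, 3], x[:, :, 3 * d_k:4 * d_k])
```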

Positional Encoding

Since attention is permutation-equivariant (no notion of order), we must inject position information.

Sinusoidal Positional Encoding

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

Why sinusoidal? For any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$, which lets the model learn to attend to relative positions:

$$PE_{pos+k} = A_k \, PE_{pos}$$

where $A_k$ is block-diagonal, with one $2 \times 2$ rotation per (sin, cos) pair of dimensions; it depends on $k$ but not on $pos$.

Worked Example — Sinusoidal Positional Encoding

Setup: $d_{\text{model}} = 4$, computing PE for positions 0, 1, 2.

Dimension indices: $i = 0, 1$ (so $2i = 0, 2$ and $2i+1 = 1, 3$)

Denominators $10000^{2i/4}$:

  • $i = 0$: $10000^{0/4} = 1$
  • $i = 1$: $10000^{2/4} = 100$

Position 0:

$$\begin{aligned} PE_{(0,0)} &= \sin(0/1) = 0, & PE_{(0,1)} &= \cos(0/1) = 1 \\ PE_{(0,2)} &= \sin(0/100) = 0, & PE_{(0,3)} &= \cos(0/100) = 1 \end{aligned} \qquad PE_0 = [0, 1, 0, 1]$$

Position 1:

$$\begin{aligned} PE_{(1,0)} &= \sin(1) = 0.841, & PE_{(1,1)} &= \cos(1) = 0.540 \\ PE_{(1,2)} &= \sin(0.01) = 0.010, & PE_{(1,3)} &= \cos(0.01) = 1.000 \end{aligned} \qquad PE_1 = [0.841, 0.540, 0.010, 1.000]$$

Position 2:

$$\begin{aligned} PE_{(2,0)} &= \sin(2) = 0.909, & PE_{(2,1)} &= \cos(2) = -0.416 \\ PE_{(2,2)} &= \sin(0.02) = 0.020, & PE_{(2,3)} &= \cos(0.02) = 1.000 \end{aligned} \qquad PE_2 = [0.909, -0.416, 0.020, 1.000]$$

Result: Low-index dimensions ($i = 0$) oscillate rapidly (like the second hand of a clock), while high-index dimensions ($i = 1$) oscillate slowly (like the hour hand). Positions 0 and 1 differ mainly in the fast dimensions; positions 0 and 100 would differ in the slow dimensions too. This multi-frequency encoding lets the model distinguish positions at many scales.
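A NumPy sketch of the sinusoidal table and of the offset matrix $A_k$ (the helper names `sinusoidal_pe` and `offset_matrix` are illustrative). It reproduces the worked example for $d_{\text{model}} = 4$, then checks that one fixed block-diagonal rotation maps every position $pos$ to $pos + k$.

```python
import numpy as np

def sinusoidal_pe(positions, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(same angle)."""
    pos = np.asarray(list(positions), dtype=float)[:, None]
    i = np.arange(d_model // 2)
    angles = pos / 10000 ** (2 * i / d_model)     # (n_pos, d_model/2)
    pe = np.zeros((pos.shape[0], d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def offset_matrix(k, d_model):
    """Block-diagonal A_k: one 2x2 rotation per (sin, cos) pair; depends only on k."""
    A = np.zeros((d_model, d_model))
    for i in range(d_model // 2):
        w = 1.0 / 10000 ** (2 * i / d_model)
        c, s = np.cos(k * w), np.sin(k * w)
        # [sin(a+b), cos(a+b)] = [[cos b, sin b], [-sin b, cos b]] @ [sin a, cos a]
        A[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return A

pe = sinusoidal_pe([0, 1, 2], d_model=4)
print(np.round(pe, 3))  # rows are PE_0, PE_1, PE_2 from the worked example

d_model, k = 16, 7
A = offset_matrix(k, d_model)
P = sinusoidal_pe(range(50), d_model)
# The same matrix A_k maps every position pos to pos + k
assert np.allclose(P[:-k] @ A.T, P[k:])
```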

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (batch, seq_len, d_model)"""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

Learned Positional Embeddings

Modern models (GPT, BERT) often use learned positional embeddings:

python
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
# Usage: x = x + self.pos_embedding(torch.arange(seq_len, device=x.device))

Rotary Position Embedding (RoPE)

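Used in LLaMA, Mistral, and many other recent LLMs. RoPE encodes position by rotating the query and key vectors: treating each pair of dimensions as a complex number $x_j$, it multiplies pair $j$ by a position-dependent phase:

$$f(x, pos)_j = x_j \, e^{i \, pos \, \theta_j}, \qquad \theta_j = 10000^{-2j/d}$$

This encodes relative position directly in the dot product: $\langle f(q, m), f(k, n) \rangle$ depends on the positions only through $m - n$.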
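A minimal NumPy sketch of RoPE, assuming the frequencies $\theta_j = 10000^{-2j/d}$ and pairing adjacent dimensions $(x_{2j}, x_{2j+1})$. Some implementations pair the first and second halves of the vector instead, which applies the same rotations to a different ordering of dimensions. The check confirms that shifting both positions by the same amount leaves the score unchanged.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate pairs (x[2j], x[2j+1]) by angle pos * theta_j, theta_j = base^(-2j/d).
    Equivalent to multiplying the complex number x[2j] + 1j*x[2j+1] by exp(1j*pos*theta_j)."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)    # (d/2,) frequencies
    z = x[0::2] + 1j * x[1::2]                   # view each pair as a complex number
    z = z * np.exp(1j * pos * theta)             # rotate by a position-dependent phase
    out = np.empty_like(x)
    out[0::2], out[1::2] = z.real, z.imag
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)

# Same offset m - n = 3 at very different absolute positions
s1 = rope(q, 5) @ rope(k, 2)
s2 = rope(q, 103) @ rope(k, 100)
print(s1, s2)
assert np.isclose(s1, s2)
```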

Feed-Forward Network

Each transformer layer contains a position-wise feed-forward network (applied identically to each position):

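$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

The original paper uses this ReLU form with an inner dimension of $4 \times d_{\text{model}}$ (2048 for $d_{\text{model}} = 512$). The implementation below swaps in GELU, the activation used by BERT and GPT-style models: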

python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))
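"Position-wise" means the same weights are applied to every position independently. A toy NumPy check using the paper's ReLU form (sizes and random weights are illustrative), plus the parameter count of one FFN at the base sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 8, 32, 5           # toy sizes; d_ff = 4 * d_model

W1, b1 = rng.standard_normal((d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)), np.zeros(d_model)

def ffn(x):
    """FFN(x) = max(0, x W1 + b1) W2 + b2 (ReLU, as in the original paper)."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

x = rng.standard_normal((seq_len, d_model))
whole = ffn(x)                                               # all positions at once
one_by_one = np.stack([ffn(x[t]) for t in range(seq_len)])   # each position alone
assert np.allclose(whole, one_by_one)

# Parameters of one FFN at the paper's base sizes (d_model=512, d_ff=2048)
n_params = 512 * 2048 + 2048 + 2048 * 512 + 512
print(n_params)  # 2099712
```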

Encoder Layer

python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward with residual connection and layer norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

Decoder Layer

The decoder adds masked self-attention (causal mask) and cross-attention to the encoder output:

python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Cross-attention (queries from decoder, keys/values from encoder)
        cross_out = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_out))

        # Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_out))
        return x

Causal Mask

For autoregressive generation, position $i$ should only attend to positions $\le i$:

python
def create_causal_mask(seq_len):
    """Lower-triangular mask: 1 = allowed, 0 = masked."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)
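To see what the mask does, this NumPy sketch applies it with the same convention as `scaled_dot_product_attention` (entries equal to 0 become $-\infty$ before softmax). With uniform scores, row $i$ spreads its attention evenly over positions $0..i$ and gives zero weight to the future:

```python
import numpy as np

def causal_mask(seq_len):
    # Lower-triangular: mask[i, j] = 1 if j <= i (allowed), else 0 (masked)
    return np.tril(np.ones((seq_len, seq_len)))

def masked_softmax(scores, mask):
    # Same convention as scaled_dot_product_attention: mask == 0 -> -inf
    scores = np.where(mask == 0, -np.inf, scores)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 4
scores = np.zeros((seq_len, seq_len))   # uniform scores make the pattern obvious
weights = masked_softmax(scores, causal_mask(seq_len))
print(np.round(weights, 3))
```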

Full Transformer

python
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_heads=8,
                 n_enc_layers=6, n_dec_layers=6, d_ff=2048,
                 max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Embeddings
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        # Encoder and decoder stacks
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_enc_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_dec_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src, src_mask=None):
        x = self.pos_encoding(self.src_embed(src) * math.sqrt(self.d_model))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        x = self.pos_encoding(self.tgt_embed(tgt) * math.sqrt(self.d_model))
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        return self.output_proj(dec_output)
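The model above expects callers to build `src_mask` and `tgt_mask`. One way to build them, assuming padding token id 0 (a hypothetical choice; the text does not fix one) and the mask shapes documented in `scaled_dot_product_attention`:

```python
import numpy as np

PAD_ID = 0  # assumption: token id 0 is padding (not specified in the text)

def make_src_mask(src):
    """(batch, src_len) ids -> (batch, 1, 1, src_len); 1 = real token, 0 = pad."""
    return (src != PAD_ID)[:, None, None, :].astype(np.int64)

def make_tgt_mask(tgt):
    """Combine padding and causal masks -> (batch, 1, tgt_len, tgt_len)."""
    tgt_len = tgt.shape[1]
    pad = (tgt != PAD_ID)[:, None, None, :]                    # hide padded keys
    causal = np.tril(np.ones((tgt_len, tgt_len), dtype=bool))  # hide future keys
    return (pad & causal[None, None]).astype(np.int64)

src = np.array([[5, 7, 9, 0, 0]])   # 3 real tokens, 2 pads
tgt = np.array([[1, 4, 6, 0]])      # 3 real tokens, 1 pad

src_mask = make_src_mask(src)
tgt_mask = make_tgt_mask(tgt)
print(src_mask.shape, tgt_mask.shape)  # (1, 1, 1, 5) (1, 1, 4, 4)
print(tgt_mask[0, 0])
```

To use these with the PyTorch model, convert them with `torch.from_numpy`; both shapes broadcast against the `(batch, heads, seq_q, seq_k)` score tensor, for self-attention and cross-attention alike.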

"Attention Is All You Need" --- Paper Walkthrough

The original transformer paper introduced:

  1. No recurrence, no convolution: purely attention-based
  2. Encoder-decoder structure with 6 layers each
  3. Multi-head self-attention with 8 heads and $d_{\text{model}} = 512$ (the base model; the big model uses 16 heads and $d_{\text{model}} = 1024$)
  4. Positional encoding via sinusoidal functions
  5. Residual connections + LayerNorm around each sub-layer
  6. Label smoothing ($\epsilon_{ls} = 0.1$) for training
  7. Warmup + inverse square root decay learning rate schedule

Training details:

  • WMT 2014 English-German: about 4.5M sentence pairs
  • 8 NVIDIA P100 GPUs: about 12 hours for the base model, 3.5 days for the big model
  • BLEU: 27.3 (base) and 28.4 (big, a new state of the art at the time)

The learning rate schedule:

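$$\text{lr} = d_{\text{model}}^{-0.5} \cdot \min\left(\text{step}^{-0.5},\; \text{step} \cdot \text{warmup}^{-1.5}\right)$$

The rate increases linearly for the first $\text{warmup}$ steps (4,000 in the paper), then decays proportionally to the inverse square root of the step number.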
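A plain-Python sketch of the schedule with the paper's settings ($d_{\text{model}} = 512$, $\text{warmup} = 4000$); `transformer_lr` is an illustrative name, and steps are counted from 1 to avoid $0^{-0.5}$:

```python
def transformer_lr(step, d_model=512, warmup=4000):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for s in [1, 1000, 4000, 16000, 100000]:
    print(s, f"{transformer_lr(s):.2e}")
```

The peak falls at step 4000 (about $7.0 \times 10^{-4}$), and quadrupling the step count after that halves the rate.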

Why Transformers Scale

Parallelization

RNNs must process tokens sequentially ($T$ serial steps). Transformers process all tokens of a sequence simultaneously with a few large matrix multiplications, exactly the workload GPUs are built for, so hardware utilization during training is far higher than with step-by-step recurrence.

Compute-Data Scaling Laws

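Kaplan et al. (2020) showed that transformer language-model loss follows power laws:

$$L(N) \propto N^{-\alpha_N}, \qquad L(D) \propto D^{-\alpha_D}, \qquad L(C) \propto C^{-\alpha_C}$$

where $N$ is the number of (non-embedding) parameters, $D$ is the dataset size in tokens, and $C$ is training compute, each law holding when the other two factors are not the bottleneck. This predictability lets teams extrapolate from small pilot runs to plan large training runs.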
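Because a power law is a straight line in log-log space, the exponent can be recovered from a handful of small runs and used to extrapolate. The numbers below are synthetic, generated from an assumed law $L(N) = (N_c/N)^{\alpha}$ with made-up $N_c$ and $\alpha$; they are not Kaplan et al.'s measurements.

```python
import numpy as np

# Hypothetical values: losses generated from an assumed power law, NOT measurements.
alpha_true, N_c = 0.08, 1e13
N = np.array([1e6, 1e7, 1e8, 1e9])       # small "pilot" model sizes
L = (N_c / N) ** alpha_true

# log L = -alpha * log N + alpha * log N_c: fit a line in log-log space
slope, intercept = np.polyfit(np.log(N), np.log(L), 1)
alpha_fit = -slope

# Extrapolate to a model 100x larger than any pilot run
N_big = 1e11
L_pred = np.exp(intercept) * N_big ** slope
print(f"alpha ~ {alpha_fit:.3f}, predicted L(1e11) ~ {L_pred:.3f}")
```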

Attention as a Learned Index

Self-attention lets the model dynamically route information based on content, not fixed connectivity. This is a form of learned conditional computation that scales gracefully with model size.

Attention Complexity and Optimization

Standard self-attention is $O(n^2)$ in sequence length, which limits context windows. Solutions:

| Method | Complexity | Approach |
|---|---|---|
| Standard | $O(n^2)$ | Full attention matrix |
| Flash Attention | $O(n^2)$ time, $O(n)$ memory | Tiled computation, no materialization of the attention matrix |
| Multi-Query Attention | $O(n^2)$, much smaller KV cache | Shared K, V across heads |
| Grouped-Query Attention | $O(n^2)$, smaller KV cache | Groups of heads share K, V |
| Sliding Window | $O(nw)$ for window size $w$ | Local attention window |
| Ring Attention | $O(n^2)$, distributed | Distribute blocks of the sequence across devices |
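For sliding-window attention, each query sees at most $w$ keys, so the number of scored pairs is about $nw$ rather than roughly $n^2/2$ for full causal attention. A NumPy sketch of a causal sliding-window mask (it builds a dense mask only to count entries; efficient kernels never materialize it):

```python
import numpy as np

def sliding_window_causal_mask(n, w):
    """mask[i, j] = True iff i - w < j <= i: each query sees at most w keys
    (itself and the w - 1 previous tokens)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - w)

n, w = 1024, 128
mask = sliding_window_causal_mask(n, w)
print(mask.sum(), "allowed pairs vs", n * (n + 1) // 2, "for full causal attention")
```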

Pre-Norm vs Post-Norm

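The original paper uses post-norm (normalize after the residual):

$$x = \text{LayerNorm}(x + \text{SubLayer}(x)) \qquad \text{(post-norm)}$$

Most modern models use pre-norm (normalize before the sublayer):

$$x = x + \text{SubLayer}(\text{LayerNorm}(x)) \qquad \text{(pre-norm)}$$

Pre-norm keeps an unnormalized identity path through the residual stream, which makes deep networks more stable to train and often allows training with little or no learning-rate warmup.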
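A toy NumPy comparison of the two orderings (LayerNorm here has no learned scale or shift, and the sublayer is a deliberately trivial stand-in). When a sublayer contributes nothing, a pre-norm block is an exact identity while a post-norm block still renormalizes its input, which illustrates the clean residual path pre-norm preserves:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm without learnable scale/shift (gamma=1, beta=0) for simplicity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    return layer_norm(x + sublayer(x))    # original paper

def pre_norm_block(x, sublayer):
    return x + sublayer(layer_norm(x))    # GPT-2 style (LLaMA uses RMSNorm instead)

def silent(h):
    # A stand-in sublayer that contributes nothing
    return np.zeros_like(h)

rng = np.random.default_rng(0)
x = 3.0 + 2.0 * rng.standard_normal((4, 16))   # (seq, d_model), not normalized

# Pre-norm: x passes straight through the residual path
assert np.allclose(pre_norm_block(x, silent), x)
# Post-norm: the residual stream is renormalized at every layer regardless
assert not np.allclose(post_norm_block(x, silent), x)
```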

KV Cache for Efficient Inference

During autoregressive generation, naive attention recomputes keys and values for all previous tokens at each step. The KV cache stores previously computed keys and values, reducing per-step computation from $O(n^2)$ to $O(n)$:

python
class CachedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            # Append to cached keys and values
            prev_k, prev_v = kv_cache
            k = torch.cat([prev_k, k], dim=2)
            v = torch.cat([prev_v, v], dim=2)

        new_cache = (k, v)

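        # Attention with full K, V but only the new queries
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Causal mask: new query t sits at absolute position past_len + t and may
        # attend to every cached key plus new keys up to itself (matters when T > 1)
        past_len = k.size(2) - T
        allowed = torch.ones(T, k.size(2), dtype=torch.bool, device=x.device).tril(diagonal=past_len)
        att = att.masked_fill(~allowed, float('-inf'))

        att = F.softmax(att, dim=-1)
        y = att @ v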

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(y), new_cache
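To check that caching changes only the cost and not the result, here is a single-head NumPy sketch (no batching, random weights, illustrative sizes): decoding one token at a time with a growing K/V cache reproduces full causal attention.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 16, 6                                   # toy head size and sequence length
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
x = rng.standard_normal((n, d))                # token representations

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Reference: full causal attention over all n tokens at once
q, k, v = x @ Wq, x @ Wk, x @ Wv
scores = q @ k.T / np.sqrt(d)
scores[np.triu_indices(n, 1)] = -np.inf        # causal mask
full = softmax(scores) @ v

# Incremental decoding: project only the new token, append its K/V to the cache
k_cache, v_cache = np.empty((0, d)), np.empty((0, d))
steps = []
for t in range(n):
    q_t = x[t:t + 1] @ Wq                      # (1, d): only the newest query
    k_cache = np.vstack([k_cache, x[t:t + 1] @ Wk])
    v_cache = np.vstack([v_cache, x[t:t + 1] @ Wv])
    s = q_t @ k_cache.T / np.sqrt(d)           # (1, t+1): the cache holds only the past
    steps.append(softmax(s) @ v_cache)

cached = np.vstack(steps)
assert np.allclose(cached, full)
print("cached decoding matches full causal attention")
```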

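Compute Savings and Memory Cost of the KV Cache

Without a cache, step $t$ recomputes keys and values for all $t$ tokens, so key/value projection work totals $O(n^2)$ over $n$ generated tokens. With a cache, each step projects only the new token ($O(1)$ per step, $O(n)$ in total). The new query must still be scored against all $t$ cached keys, so attention itself costs $O(t)$ at step $t$.

The trade-off is memory: the KV cache grows linearly with sequence length. For a model with $L$ layers, $h_{kv}$ key/value heads, and $d_k$ per head, one sequence of $n$ tokens needs:

$$\text{KV cache memory} = 2 \times L \times n \times h_{kv} \times d_k \times \text{bytes per element}$$

where the factor 2 accounts for storing both K and V. For LLaMA 2 70B (80 layers, 8 KV heads of dimension 128) at 4,096 tokens in FP16, that is $2 \times 80 \times 4096 \times 8 \times 128 \times 2 \approx 1.34$ GB per sequence; with a separate K/V for each of its 64 heads it would be about 10.7 GB.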
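A small calculator for the formula above (`kv_cache_bytes` is an illustrative helper; the LLaMA 2 70B shape values come from its published configuration, and batch size 1 is assumed):

```python
def kv_cache_bytes(n_layers, n_tokens, n_kv_heads, head_dim, bytes_per_elem=2, batch=1):
    """2 (K and V) x layers x tokens x KV heads x head_dim x bytes, times batch size."""
    return 2 * n_layers * n_tokens * n_kv_heads * head_dim * bytes_per_elem * batch

# LLaMA 2 70B: 80 layers, 64 query heads, 8 KV heads (GQA), head_dim 128, FP16
gqa = kv_cache_bytes(80, 4096, n_kv_heads=8, head_dim=128)
mha = kv_cache_bytes(80, 4096, n_kv_heads=64, head_dim=128)  # if every head had its own K/V
print(f"GQA: {gqa / 1e9:.2f} GB, full MHA: {mha / 1e9:.2f} GB")  # GQA: 1.34 GB, full MHA: 10.74 GB
```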

Grouped-Query and Multi-Query Attention

Standard multi-head attention uses separate K, V projections per head. This makes the KV cache large.

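Multi-Query Attention (MQA): All heads share the same K and V. The KV cache shrinks by a factor of $h$.

Grouped-Query Attention (GQA): Groups of heads share K, V. A compromise between MHA and MQA.

| Method | K, V per head | KV cache size (per token and layer, for each of K and V) | Quality |
|---|---|---|---|
| MHA | Unique per head | Large ($h \times d_k$) | Best |
| GQA | Shared per group | Medium ($g \times d_k$ for $g$ groups) | Near-MHA |
| MQA | Single shared | Small ($d_k$) | Slightly worse |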

LLaMA 2 70B uses GQA with 8 KV heads (vs 64 query heads).
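A NumPy sketch of grouped-query attention for one sequence (no batch dimension; sizes are illustrative). Each KV head is repeated to cover its group of consecutive query heads, so $g = h$ recovers MHA and $g = 1$ gives MQA. In a real model only the $g$ KV heads are cached; the repetition happens at compute time.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (h, n, d_k); k, v: (g, n, d_k) with h divisible by g.
    Consecutive query heads share one KV head."""
    h, g = q.shape[0], k.shape[0]
    assert h % g == 0
    k = np.repeat(k, h // g, axis=0)   # (h, n, d_k): each KV head covers its group
    v = np.repeat(v, h // g, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
h, g, n, d_k = 8, 2, 5, 16
q = np.repeat(rng.standard_normal((1, n, d_k)), h, axis=0)  # identical queries in every head
k = rng.standard_normal((g, n, d_k))
v = rng.standard_normal((g, n, d_k))

out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 5, 16)
# Heads 0-3 share KV head 0; heads 4-7 share KV head 1
assert np.allclose(out[0], out[3]) and np.allclose(out[4], out[7])
assert not np.allclose(out[0], out[4])
print("KV cache shrinks by", h // g, "x versus MHA")
```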

Flash Attention

Flash Attention (Dao et al., 2022) computes exact attention without materializing the full $n \times n$ attention matrix, reducing memory from $O(n^2)$ to $O(n)$:

  1. Split Q, K, V into blocks that fit in SRAM
  2. Compute attention block-by-block using online softmax
  3. Never write the full attention matrix to GPU HBM

Results: 2-4x speedup and 5-20x memory reduction, with exact (not approximate) computation.

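python
# PyTorch 2.0+ ships a fused attention function that can dispatch to Flash Attention
import torch
import torch.nn.functional as F

Q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq, d_k)
K, V = torch.randn_like(Q), torch.randn_like(Q)

# PyTorch selects a backend (Flash, memory-efficient, or plain math) based on
# device, dtype, and shapes; the Flash kernel needs a supported CUDA GPU.
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)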
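What the fused kernel computes can be sketched on the CPU. The NumPy code below is a simplified illustration of the block-by-block online softmax from the steps listed above, not the FlashAttention kernel itself: it tiles only over keys, keeps a running row maximum and denominator, rescales earlier partial sums when the maximum grows, and matches ordinary softmax attention up to floating-point error.

```python
import numpy as np

def blockwise_attention(Q, K, V, block=32):
    """Exact softmax attention over key/value blocks with an online softmax;
    never forms the full (n_q, n_k) score matrix at once."""
    n_q, d = Q.shape
    out = np.zeros((n_q, V.shape[1]))
    row_max = np.full((n_q, 1), -np.inf)   # running max of scores per query
    row_sum = np.zeros((n_q, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)          # (n_q, block) scores for this block only
        new_max = np.maximum(row_max, S.max(axis=1, keepdims=True))
        correction = np.exp(row_max - new_max)   # rescale earlier partial results
        P = np.exp(S - new_max)
        row_sum = row_sum * correction + P.sum(axis=1, keepdims=True)
        out = out * correction + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))

# Reference: ordinary softmax attention with the full score matrix
S = Q @ K.T / np.sqrt(16)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V

assert np.allclose(blockwise_attention(Q, K, V, block=32), reference)
```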

Transformer Debugging Tips

| Issue | Symptom | Fix |
|---|---|---|
| Attention to wrong positions | Model ignores relevant tokens | Check the causal mask is applied correctly |
| Positional encoding missing | Model treats input as a bag of tokens | Add positional encoding/embedding |
| Post-norm instability | Loss spikes during training | Switch to pre-norm |
| Attention overflow | NaN in attention weights | Scale scores by $1/\sqrt{d_k}$; check for rows where every key is masked (softmax over all $-\infty$ gives NaN) |
| Embedding scale | Poor early training | Multiply embeddings by $\sqrt{d_{\text{model}}}$ |

Cross-References

"What I cannot create, I do not understand." — Richard Feynman