Unverified: AI-generated content.

Training Techniques

Getting a neural network to train well is as much art as science. The difference between a model that converges to 70% accuracy and one that reaches 95% often comes down to training techniques, not architecture. This page covers the essential toolkit: batch normalization, dropout, weight initialization, learning rate scheduling, gradient clipping, mixed precision, early stopping, and data augmentation.

Batch Normalization

Batch normalization (Ioffe and Szegedy, 2015) normalizes the inputs to each layer, stabilizing training and enabling higher learning rates.

The Math

For a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$:

Step 1 --- Mini-batch mean:

$$\mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i$$

Step 2 --- Mini-batch variance:

$$\sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2$$

Step 3 --- Normalize:

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}$$

Step 4 --- Scale and shift (learnable):

$$y_i = \gamma \hat{x}_i + \beta$$

where $\gamma$ (scale) and $\beta$ (shift) are learnable parameters. The $\epsilon$ (typically $10^{-5}$) prevents division by zero.

Worked Example — Batch Normalization Calculation

Input: Mini-batch of $m = 4$ values from one feature/channel: $\mathcal{B} = \{1.0, 3.0, 5.0, 7.0\}$

Learnable parameters: $\gamma = 1.2$, $\beta = 0.5$, $\epsilon = 10^{-5}$

Step 1 --- Mean:

$$\mu_\mathcal{B} = \frac{1 + 3 + 5 + 7}{4} = 4.0$$

Step 2 --- Variance:

$$\sigma_\mathcal{B}^2 = \frac{(1-4)^2 + (3-4)^2 + (5-4)^2 + (7-4)^2}{4} = \frac{9 + 1 + 1 + 9}{4} = 5.0$$

Step 3 --- Normalize:

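| $x_i$ | $\hat{x}_i = \frac{x_i - 4}{\sqrt{5 + 10^{-5}}}$ |
|---|---|
| 1.0 | $-3 / 2.236 = -1.342$ |
| 3.0 | $-1 / 2.236 = -0.447$ |
| 5.0 | $1 / 2.236 = 0.447$ |
| 7.0 | $3 / 2.236 = 1.342$ |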

Step 4 --- Scale and shift:

| $\hat{x}_i$ | $y_i = 1.2\,\hat{x}_i + 0.5$ |
|---|---|
| -1.342 | -1.110 |
| -0.447 | -0.037 |
| 0.447 | 1.037 |
| 1.342 | 2.110 |

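Result: The raw values [1, 3, 5, 7] are normalized to zero mean and unit variance, then rescaled by the learned $\gamma, \beta$, so the outputs have mean 0.5 and standard deviation of about 1.2. The network can learn to undo BatchNorm ($\gamma = \sqrt{\sigma_\mathcal{B}^2 + \epsilon}$, $\beta = \mu_\mathcal{B}$) if needed, but the gradient landscape is smoother.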
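The four steps can be checked numerically. The sketch below is a minimal NumPy version of the training-mode forward pass for inputs of shape (m, num_features); `batchnorm_forward` is an illustrative name, not a library function. It reproduces the worked example above.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm for x of shape (m, num_features)."""
    mu = x.mean(axis=0)                      # Step 1: per-feature mini-batch mean
    var = x.var(axis=0)                      # Step 2: biased variance (divides by m)
    x_hat = (x - mu) / np.sqrt(var + eps)    # Step 3: normalize
    return gamma * x_hat + beta, x_hat       # Step 4: learnable scale and shift

# The worked example: a single feature, m = 4
x = np.array([[1.0], [3.0], [5.0], [7.0]])
y, x_hat = batchnorm_forward(x, gamma=1.2, beta=0.5)
# x_hat ≈ [-1.342, -0.447, 0.447, 1.342]
# y     ≈ [-1.110, -0.037, 1.037, 2.110]
```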

Why It Works

  1. Reduces internal covariate shift (the original motivation): Each layer receives inputs with more stable statistics, so it does not need to keep readjusting to shifting distributions.
  2. Allows higher learning rates: Normalized inputs keep activations and gradients in a well-behaved range.
  3. Acts as regularization: The noise from mini-batch statistics acts like a mild regularizer.
  4. Smooths the loss landscape: Santurkar et al. (2018) argue that this smoothing, rather than reduced covariate shift, is the main reason BatchNorm helps optimization.

Training vs Inference

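During training, BatchNorm uses the current mini-batch statistics ($\mu_\mathcal{B}, \sigma_\mathcal{B}^2$). During inference, it uses running averages accumulated during training:

$$\mu_{\text{running}} \leftarrow (1 - \alpha)\,\mu_{\text{running}} + \alpha\,\mu_\mathcal{B}$$
$$\sigma^2_{\text{running}} \leftarrow (1 - \alpha)\,\sigma^2_{\text{running}} + \alpha\,\sigma^2_\mathcal{B}$$

where $\alpha$ (momentum) is typically 0.1.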

model.eval() Is Critical

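If you forget to call model.eval() before inference, BatchNorm normalizes each batch with its own statistics instead of the running statistics (and keeps updating the running statistics), so a sample's prediction depends on the other samples in its batch. The problem is worst with batch size 1, where the "batch statistics" come from a single sample.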
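The switch between training and inference behavior can be sketched in NumPy. `SimpleBatchNorm1d` is a hypothetical teaching class, not PyTorch's implementation: it applies the running-average updates above using the biased batch variance, whereas PyTorch's nn.BatchNorm1d stores the unbiased (divide by $m-1$) batch variance in `running_var`. The last lines show what goes wrong with a one-sample batch in training mode.

```python
import numpy as np

class SimpleBatchNorm1d:
    """Hypothetical teaching version of BatchNorm for inputs of shape (m, num_features)."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.momentum, self.eps = momentum, eps
        self.running_mean = np.zeros(num_features)  # same initial values PyTorch uses
        self.running_var = np.ones(num_features)
        self.training = True

    def __call__(self, x):
        if self.training:
            # Normalize with this batch's statistics and update the running averages
            mu, var = x.mean(axis=0), x.var(axis=0)
            a = self.momentum
            self.running_mean = (1 - a) * self.running_mean + a * mu
            self.running_var = (1 - a) * self.running_var + a * var
        else:
            # Inference: use the running averages, so each sample is handled independently
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

bn = SimpleBatchNorm1d(num_features=1)
bn(np.array([[1.0], [3.0], [5.0], [7.0]]))  # one training step on the worked example
# running_mean = 0.9 * 0 + 0.1 * 4.0 = 0.4; running_var = 0.9 * 1 + 0.1 * 5.0 = 1.4

bn.training = False                 # the flag model.eval() switches off in PyTorch
single = bn(np.array([[3.0]]))      # batch size 1 is fine at inference

# Forgetting eval(): a one-sample batch has variance 0 and equals its own mean,
# so every input collapses to beta (0 here), whatever its value
collapsed = SimpleBatchNorm1d(num_features=1)(np.array([[3.0]]))
```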

PyTorch Implementation

python
import torch.nn as nn

# For fully connected layers
bn1d = nn.BatchNorm1d(num_features=256)

# For convolutional layers (normalizes per channel)
bn2d = nn.BatchNorm2d(num_features=64)

# In a model
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
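        # bias=False: BatchNorm's learnable shift (beta) makes a conv bias redundant
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)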
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu(out)
        return out

Layer Normalization vs Batch Normalization

| Feature | BatchNorm | LayerNorm |
|---|---|---|
| Normalizes across | Batch dimension | Feature dimension |
| Depends on batch size | Yes | No |
| Works with small batches | Poorly | Well |
| Used in | CNNs | Transformers, RNNs |
| Training/inference difference | Yes (running stats) | No |

LayerNorm formula (for a single sample):

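$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \quad \mu = \frac{1}{H} \sum_{i=1}^{H} x_i, \quad \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2$$

where $H$ is the number of features. As in BatchNorm, the result is then scaled and shifted by learnable $\gamma$ and $\beta$ (one per feature).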
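The difference in normalization axis is easiest to see side by side. A minimal NumPy sketch (function names are illustrative; the learnable $\gamma$ and $\beta$ are omitted):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each row (one sample) over its H features; gamma/beta omitted."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def batch_norm(x, eps=1e-5):
    """Normalize each column (one feature) over the batch; gamma/beta omitted."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 16))  # 8 samples, H = 16 features

ln = layer_norm(x)  # every row now has mean ~0 and variance ~1

# LayerNorm of the first sample is the same with or without the rest of the batch...
same_alone = np.allclose(layer_norm(x[:1]), ln[:1])
# ...whereas BatchNorm's output for the first sample changes with its batch-mates
bn_changes = not np.allclose(batch_norm(x[:4])[:1], batch_norm(x)[:1])
print(same_alone, bn_changes)
```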

Dropout

Dropout (Srivastava et al., 2014) randomly sets neuron activations to zero during training. This discourages co-adaptation: neurons cannot rely on specific other neurons being present.

The Math

During training, each neuron's output is set to zero with probability p:

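$$h_i^{\text{dropped}} = \begin{cases} 0 & \text{with probability } p \\ \dfrac{h_i}{1 - p} & \text{with probability } 1 - p \end{cases}$$

The $\frac{1}{1-p}$ scaling (inverted dropout) keeps the expected value unchanged: $\mathbb{E}[h_i^{\text{dropped}}] = h_i$.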

Worked Example — Dropout with Inverted Scaling

Input: Hidden activations $\mathbf{h} = [0.8, 1.2, 0.5, 2.0]$, dropout rate $p = 0.5$

Step 1: Generate random mask (sample: keep neurons 0, 2; drop neurons 1, 3)

$$\text{mask} = [1, 0, 1, 0]$$

Step 2: Apply mask and scale by $\frac{1}{1-p} = \frac{1}{0.5} = 2$

$$\mathbf{h}^{\text{dropped}} = [0.8 \times 2,\ 0,\ 0.5 \times 2,\ 0] = [1.6, 0, 1.0, 0]$$

Verify expected value:

  • $\mathbb{E}[h_0^{\text{dropped}}] = 0.5 \times \frac{0.8}{0.5} + 0.5 \times 0 = 0.8$ (equals original $h_0$)

During inference: No dropout; use the raw $\mathbf{h} = [0.8, 1.2, 0.5, 2.0]$ directly (no scaling needed because inverted dropout already compensated during training).

Result: The scaling factor of 2 ensures that the sum of activations at training time has the same expected value as at inference time, so the next layer receives consistent input magnitudes.
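A minimal NumPy sketch of inverted dropout follows; the function name and signature are illustrative, not PyTorch's API. Averaging many independent draws shows the expected value matching the original activations, as in the check above.

```python
import numpy as np

def inverted_dropout(h, p, rng, training=True):
    """Zero each activation with probability p; scale survivors by 1 / (1 - p)."""
    if not training or p == 0.0:
        return h                          # inference: identity, no rescaling
    keep = rng.random(h.shape) >= p       # True with probability 1 - p
    return h * keep / (1.0 - p)

rng = np.random.default_rng(0)
h = np.array([0.8, 1.2, 0.5, 2.0])

# 100,000 independent dropout draws of h, one per row
samples = inverted_dropout(np.tile(h, (100_000, 1)), 0.5, rng)
print(samples.mean(axis=0))                           # close to h
print(inverted_dropout(h, 0.5, rng, training=False))  # exactly h
```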


Where to Place Dropout

python
class ClassifierWithDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Dropout(0.5),       # After activation, before next layer
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),       # Can vary rates by layer
            nn.Linear(256, 10),    # No dropout before output layer
        )

    def forward(self, x):
        return self.features(x)

Dropout Rates by Architecture

| Architecture | Typical Dropout Rate |
|---|---|
| MLP hidden layers | 0.5 |
| CNN after conv blocks | 0.25 |
| CNN classifier head | 0.5 |
| Transformer attention | 0.1 |
| Transformer FFN | 0.1 |
| RNN (between layers) | 0.2--0.5 |

Weight Initialization

Proper initialization prevents vanishing and exploding gradients at the start of training. The variance of activations should remain roughly constant across layers.

Xavier/Glorot Initialization

For layers with sigmoid or tanh activations:

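$$W \sim U\left(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}},\ \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right)$$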

or the normal variant:

$$W \sim \mathcal{N}\left(0,\ \frac{2}{n_{\text{in}} + n_{\text{out}}}\right)$$

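Derivation sketch: Keeping the activation variance constant in the forward pass requires $\operatorname{Var}(W) = \frac{1}{n_{\text{in}}}$, while keeping the gradient variance constant in the backward pass requires $\operatorname{Var}(W) = \frac{1}{n_{\text{out}}}$. Xavier initialization compromises between the two with $\operatorname{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}$.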
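The uniform and normal variants have the same variance. A uniform distribution on $[-a, a]$ has variance $a^2/3$, so the Xavier limit gives:

$$a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \quad\Rightarrow\quad \operatorname{Var}(W) = \frac{a^2}{3} = \frac{1}{3} \cdot \frac{6}{n_{\text{in}} + n_{\text{out}}} = \frac{2}{n_{\text{in}} + n_{\text{out}}}$$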

He/Kaiming Initialization

For layers with ReLU activations (He et al., 2015):

$$W \sim \mathcal{N}\left(0,\ \frac{2}{n_{\text{in}}}\right)$$

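ReLU zeros out roughly half of its inputs, which halves the second moment of the activations, so the weight variance is doubled relative to the forward-pass condition $\frac{1}{n_{\text{in}}}$. This is the standard choice for modern networks with ReLU activations.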
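A quick NumPy experiment makes the argument concrete: push unit-variance inputs through ten Linear+ReLU layers and track the pre-activation variance at each layer. The helper `preactivation_variances` is hypothetical, written for this illustration, and biases are taken to be zero, as at initialization.

```python
import numpy as np

def preactivation_variances(weight_std_fn, depth=10, width=512, batch=1000, seed=0):
    """Run random inputs through `depth` Linear+ReLU layers; return each layer's pre-activation variance."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((batch, width))   # unit-variance inputs
    variances = []
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * weight_std_fn(width)
        z = a @ W                             # pre-activation (zero biases at init)
        variances.append(z.var())
        a = np.maximum(z, 0.0)                # ReLU
    return variances

he = preactivation_variances(lambda n_in: np.sqrt(2.0 / n_in))  # He: Var(W) = 2 / n_in
tiny = preactivation_variances(lambda n_in: 0.01)               # fixed small std, ignores n_in
print([round(v, 2) for v in he])   # stays near 2 at every depth
print(f"{tiny[-1]:.1e}")           # shrinks by many orders of magnitude: vanishing signal
```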

PyTorch Initialization

python
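import torch.nn as nn

def init_weights(m):
    if isinstance(m, nn.Linear):
        # He init for ReLU layers, scaled by fan_in (preserves forward-pass variance)
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        # fan_out preserves the variance of gradients in the backward pass
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # Start BatchNorm as the identity: gamma = 1, beta = 0
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# `model` is any nn.Module; apply() calls init_weights on every submodule
model.apply(init_weights)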

Initialization Summary

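| Activation | Initialization | Variance |
|---|---|---|
| Sigmoid/Tanh | Xavier (Glorot) | $\frac{2}{n_{\text{in}} + n_{\text{out}}}$ |
| ReLU | He (Kaiming) | $\frac{2}{n_{\text{in}}}$ |
| SELU | LeCun | $\frac{1}{n_{\text{in}}}$ |
| Any (output layer) | Xavier | $\frac{2}{n_{\text{in}} + n_{\text{out}}}$ |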

Learning Rate Scheduling

The learning rate is usually the most important hyperparameter to tune. A fixed LR is rarely optimal; scheduling it during training usually improves results.

Step Decay

$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t / s \rfloor}$$

where γ (e.g., 0.1) is the decay factor and s is the step size in epochs.

Worked Example — Step Decay Learning Rate

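Setup: $\eta_0 = 0.01$, $\gamma = 0.1$, step size $s = 30$ epochs

| Epoch | $\lfloor t/s \rfloor$ | $\eta_t = 0.01 \times 0.1^{\lfloor t/30 \rfloor}$ |
|---|---|---|
| 0 | 0 | $0.01 \times 1 = 0.01$ |
| 29 | 0 | 0.01 |
| 30 | 1 | $0.01 \times 0.1 = 0.001$ |
| 59 | 1 | 0.001 |
| 60 | 2 | $0.01 \times 0.01 = 0.0001$ |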

Result: The LR drops by 10x at epochs 30 and 60. This is the "multi-step" schedule commonly used in ResNet papers.
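The schedule is a one-liner. A plain-Python sketch (the function name is illustrative) that reproduces the table:

```python
def step_decay(epoch, eta0=0.01, gamma=0.1, step_size=30):
    """eta_t = eta0 * gamma ** floor(epoch / step_size)"""
    return eta0 * gamma ** (epoch // step_size)

for epoch in [0, 29, 30, 59, 60]:
    print(epoch, step_decay(epoch))
```

In PyTorch the same schedule is available as optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1), stepped once per epoch; MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1) gives the same result here.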

Cosine Annealing

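$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t}{T}\pi\right)\right)$$

Smoothly decays the learning rate from $\eta_{\max}$ to $\eta_{\min}$ over $T$ steps, with no sharp drops. It is widely used in modern training recipes.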

Worked Example — Cosine Annealing Schedule

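Setup: $\eta_{\max} = 0.01$, $\eta_{\min} = 0.0001$, $T = 100$ epochs

$$\eta_t = 0.0001 + \frac{1}{2}(0.01 - 0.0001)\left(1 + \cos\left(\frac{t}{100}\pi\right)\right)$$

| Epoch $t$ | $\cos(t\pi/100)$ | $\eta_t$ |
|---|---|---|
| 0 | $\cos(0) = 1.0$ | $0.0001 + 0.00495 \times 2.0 = 0.0100$ |
| 25 | $\cos(\pi/4) = 0.707$ | $0.0001 + 0.00495 \times 1.707 = 0.0086$ |
| 50 | $\cos(\pi/2) = 0$ | $0.0001 + 0.00495 \times 1.0 = 0.0051$ |
| 75 | $\cos(3\pi/4) = -0.707$ | $0.0001 + 0.00495 \times 0.293 = 0.0016$ |
| 100 | $\cos(\pi) = -1$ | $0.0001 + 0.00495 \times 0 = 0.0001$ |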

Result: The LR starts at 0.01 and decays smoothly to 0.0001 along a cosine curve. Unlike step decay, there are no abrupt drops. The decay is slow near the start and the end and fastest in the middle: at epoch 25 the LR is still about 85% of its initial value, and at epoch 50 it is about half.
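The closed form translates directly into code. A plain-Python sketch (illustrative function name) that reproduces the table values:

```python
import math

def cosine_annealing(t, eta_max=0.01, eta_min=0.0001, T=100):
    """eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T))"""
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T))

for t in [0, 25, 50, 75, 100]:
    print(t, round(cosine_annealing(t), 5))
```

PyTorch provides this schedule as optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-4).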

Cosine Annealing with Warm Restarts

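$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{T_{\text{cur}}}{T_i}\pi\right)\right)$$

where $T_{\text{cur}}$ counts the epochs since the last restart and $T_i$ is the length of the current cycle. Periodically resetting the LR to $\eta_{\max}$ can help the optimizer escape poor local minima.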
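A plain-Python sketch of the restart logic, assuming cycles that start at $T_0$ epochs and grow by a factor $T_{\text{mult}}$ after each restart (the SGDR convention); `sgdr_lr` and its defaults are illustrative:

```python
import math

def sgdr_lr(epoch, eta_max=0.01, eta_min=0.0001, t_0=10, t_mult=2):
    """Cosine annealing with warm restarts; cycle i lasts t_0 * t_mult**i epochs."""
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:        # skip over completed cycles
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

for epoch in [0, 5, 9, 10, 20, 29, 30]:
    print(epoch, round(sgdr_lr(epoch), 5))
```

With these defaults the LR jumps back to $\eta_{\max}$ at epochs 10 and 30. PyTorch's built-in version is optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-4).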

Warmup + Cosine Decay

Start with a small LR and linearly increase to the target LR over a warmup period, then cosine decay:

python
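import math
import torch.optim as optim

# Assumes `model` is an existing nn.Module
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Warmup for 5 epochs, then cosine decay for the remaining epochs
warmup_epochs = 5
total_epochs = 100

def lr_lambda(epoch):
    """Multiplier applied to the base LR (1e-3) at a given epoch."""
    if epoch < warmup_epochs:
        # Linear warmup; the +1 keeps the first epoch from training with LR = 0
        return (epoch + 1) / warmup_epochs
    # Cosine decay of the multiplier from 1.0 toward 0.0
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)  # call scheduler.step() once per epoch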

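One-Cycle Policy

The one-cycle policy ramps the LR up to a peak and then anneals it to far below its starting value, within a single cycle spanning the whole run. In PyTorch's OneCycleLR, the starting LR is max_lr / div_factor (div_factor defaults to 25), and the schedule advances once per batch: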

python
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    steps_per_epoch=len(train_loader),
    epochs=total_epochs,
    pct_start=0.3,       # 30% of training is warmup
    anneal_strategy='cos',
)

# Call scheduler.step() after each BATCH, not each epoch
for epoch in range(total_epochs):
    for batch in train_loader:
        # ... training step ...
        scheduler.step()

LR Finder

Estimate a good learning rate by training briefly with an exponentially increasing LR and plotting the loss against it:

python
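import copy
import torch.optim as optim
import matplotlib.pyplot as plt

def lr_finder(model, train_loader, criterion, start_lr=1e-7, end_lr=1, num_steps=100):
    # Save the weights so the sweep does not leave the model in a diverged state
    initial_state = copy.deepcopy(model.state_dict())
    optimizer = optim.SGD(model.parameters(), lr=start_lr)
    # Per-step multiplier chosen so the sweep ends at end_lr on step num_steps - 1
    lr_mult = (end_lr / start_lr) ** (1 / (num_steps - 1))

    lrs, losses = [], []
    best_loss = float('inf')
    model.train()

    for i, (inputs, targets) in enumerate(train_loader):
        if i >= num_steps:
            break

        # Set this step's LR before the update, so each recorded LR is the one actually used
        current_lr = start_lr * (lr_mult ** i)
        for pg in optimizer.param_groups:
            pg['lr'] = current_lr

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        lrs.append(current_lr)
        losses.append(loss.item())

        # Stop once the loss has clearly diverged
        if loss.item() > 4 * best_loss:
            break
        best_loss = min(best_loss, loss.item())

    # Restore the original weights
    model.load_state_dict(initial_state)

    # Plot and pick an LR where the loss decreases fastest
    plt.semilogx(lrs, losses)
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('LR Finder')
    plt.show()
    return lrs, losses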

Gradient Clipping

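Prevents exploding gradients by capping the size of the gradient. It is standard practice for RNNs and often helps other deep networks.

By Norm

If the global L2 norm of all gradients exceeds a threshold $\tau$, rescale them so the norm equals $\tau$ (the direction is preserved):

$$\text{if } \|\mathbf{g}\| > \tau: \quad \mathbf{g} \leftarrow \tau \frac{\mathbf{g}}{\|\mathbf{g}\|}$$
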
python
import torch

# Usage in the training loop: backward, then clip, then step
loss.backward()
# Rescale all gradients together so their global L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

By Value

Clips each gradient element to [-clip_value, clip_value] independently, which can change the direction of the overall gradient:

python
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
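The two clipping modes behave differently on the same gradient. A minimal NumPy sketch (the helper names are illustrative, not PyTorch functions) using a single gradient $[3, 4]$ with norm 5:

```python
import numpy as np

def clip_by_norm(grads, max_norm):
    """Scale ALL gradient arrays by one factor so their global L2 norm is <= max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # epsilon guards against division by zero
    return [g * scale for g in grads], total_norm

def clip_by_value(grads, clip_value):
    """Clamp every gradient element to [-clip_value, clip_value] independently."""
    return [np.clip(g, -clip_value, clip_value) for g in grads]

grads = [np.array([3.0, 4.0])]                 # global norm 5
by_norm, norm = clip_by_norm(grads, max_norm=1.0)
by_value = clip_by_value(grads, clip_value=0.5)
print(norm, by_norm[0], by_value[0])
```

Clipping by norm yields roughly [0.6, 0.8], pointing the same way as the original gradient; clipping by value yields [0.5, 0.5], which points in a different direction.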

Monitoring Gradient Norms

python
def get_grad_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.data.norm(2).item() ** 2
    return total_norm ** 0.5

# Log this during training to decide clipping threshold
grad_norm = get_grad_norm(model)
print(f"Gradient norm: {grad_norm:.4f}")

Mixed Precision Training

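Mixed precision runs most operations in float16 (or bfloat16) and keeps numerically sensitive ones in float32 (reductions, loss computation, and the weight update on float32 master weights). On GPUs with Tensor Cores this typically gives a 1.5--2x speedup and substantially reduces activation memory.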

PyTorch AMP (Automatic Mixed Precision)

python
from torch.amp import autocast, GradScaler

scaler = GradScaler('cuda')

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()

    # Forward pass in mixed precision
    with autocast('cuda'):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Why Gradient Scaling?

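Float16 has a limited range (roughly $6 \times 10^{-8}$ to $6.5 \times 10^{4}$). Small gradients can underflow to zero. The GradScaler:

  1. Scales the loss by a large factor before backward(), so small gradients stay representable
  2. Unscales the gradients in scaler.step(optimizer) and skips the update if they contain inf or NaN
  3. Dynamically adjusts the scale factor in scaler.update(): it shrinks the scale after an overflow and grows it again after a run of successful steps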
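The underflow problem, and how loss scaling avoids it, can be reproduced with NumPy's float16 type. This is a sketch of the arithmetic only, not of GradScaler itself; the gradient value is made up, and $2^{16}$ is GradScaler's default initial scale.

```python
import numpy as np

grad = 1e-8          # a small gradient value (made up for illustration)
scale = 2.0 ** 16    # loss-scale factor; 2**16 is GradScaler's default initial scale

naive = np.float16(grad)                # underflows to 0: below FP16's smallest subnormal (~6e-8)
scaled = np.float16(grad * scale)       # ~6.6e-4, comfortably inside FP16's range
recovered = np.float32(scaled) / scale  # unscale in FP32, as happens before the optimizer step

with np.errstate(over='ignore'):
    overflow = np.float16(70000.0)      # above FP16's maximum of 65504, becomes inf

print(naive, scaled, recovered, overflow)
```

The overflow case is why the scale cannot simply be set very large: too big a factor turns gradients into inf, which is what the dynamic adjustment in step 3 detects and corrects.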

Memory Savings

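| Precision | Weights | Gradients | Activations | Typical Speedup |
|---|---|---|---|---|
| FP32 | 4 bytes/param | 4 bytes/param | 4 bytes/value | 1x |
| Mixed (AMP autocast) | 4 bytes/param (FP32 master weights) | 4 bytes/param | mostly 2 bytes/value | 1.5--2x |
| BF16 (model cast to BF16) | 2 bytes/param | 2 bytes/param | 2 bytes/value | 1.5--2x |

With autocast, parameters stay in float32 and half-precision copies are made on the fly, so the memory savings come mainly from activations, which often dominate memory use at large batch sizes.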

Early Stopping

Stop training when validation loss stops improving to prevent overfitting.

python
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Usage
early_stop = EarlyStopping(patience=10)

for epoch in range(max_epochs):
    train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)

    early_stop(val_loss)
    if early_stop.should_stop:
        print(f"Early stopping at epoch {epoch}")
        break

Data Augmentation

Data augmentation artificially enlarges the training set by applying random, label-preserving transformations. It is one of the cheapest and most effective regularizers.

Image Augmentation

python
import torchvision.transforms as T

train_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(15),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
    T.RandomErasing(p=0.2),  # operates on tensors, so it must come after ToTensor
])

Advanced Augmentation: Mixup

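Mixup creates virtual training examples by linearly interpolating pairs of inputs and their (one-hot) labels:

$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \qquad \tilde{y} = \lambda y_i + (1 - \lambda) y_j$$

where $\lambda \sim \text{Beta}(\alpha, \alpha)$, for example with $\alpha = 0.2$.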

python
import numpy as np
import torch

def mixup_data(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
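mixup_criterion mixes two losses instead of building the soft label $\tilde{y}$. The two are equivalent because cross-entropy is linear in the target distribution. A NumPy check with arbitrary logits (the `cross_entropy` helper is written here for illustration):

```python
import numpy as np

def cross_entropy(logits, target_probs):
    """Cross-entropy between softmax(logits) and a target distribution."""
    log_probs = logits - np.log(np.sum(np.exp(logits)))   # log-softmax
    return -np.sum(target_probs * log_probs)

logits = np.array([2.0, -1.0, 0.5])     # model output for one mixed image
y_a, y_b = np.eye(3)[0], np.eye(3)[2]   # one-hot labels of the two source images
lam = 0.7

soft_label_loss = cross_entropy(logits, lam * y_a + (1 - lam) * y_b)
mixed_loss = lam * cross_entropy(logits, y_a) + (1 - lam) * cross_entropy(logits, y_b)
print(soft_label_loss, mixed_loss)      # equal up to floating-point rounding
```

Recent PyTorch versions of nn.CrossEntropyLoss also accept class probabilities as targets, so the soft label can be passed directly.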

Advanced Augmentation: CutMix

CutMix replaces a random rectangular region of each image with the same region from another image in the batch, and mixes the labels in proportion to the pasted area:

python
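import numpy as np
import torch

def cutmix_data(x, y, alpha=1.0):
    """x: image batch of shape (N, C, H, W); y: integer labels of shape (N,)."""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    # Box side lengths chosen so the box covers about (1 - lam) of the image
    H, W = x.size(2), x.size(3)
    cut_ratio = np.sqrt(1 - lam)
    cut_h = int(H * cut_ratio)
    cut_w = int(W * cut_ratio)

    # Random box center; the box is clipped at the image borders
    cy = np.random.randint(H)
    cx = np.random.randint(W)
    y1 = int(np.clip(cy - cut_h // 2, 0, H))
    y2 = int(np.clip(cy + cut_h // 2, 0, H))
    x1 = int(np.clip(cx - cut_w // 2, 0, W))
    x2 = int(np.clip(cx + cut_w // 2, 0, W))

    # Paste the box from the shuffled batch (clone so the caller's tensor is not mutated)
    x = x.clone()
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Recompute lambda from the actual pasted area (clipping can shrink the box)
    lam = 1 - ((y2 - y1) * (x2 - x1)) / (H * W)

    # Use with mixup_criterion(criterion, pred, y_a, y_b, lam)
    return x, y, y[index], lam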

Text Augmentation

python
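# Common text augmentation techniques (EDA-style word operations plus back-translation)
import random
from nltk.corpus import wordnet  # needs the corpus once: nltk.download('wordnet')

def _synonyms(word):
    """WordNet lemma names for `word`, excluding the word itself."""
    names = {lemma.name().replace('_', ' ')
             for syn in wordnet.synsets(word) for lemma in syn.lemmas()}
    names.discard(word)
    return sorted(names)

def synonym_replacement(words, n=1):
    """Replace up to n random words with a random WordNet synonym."""
    new_words = list(words)
    candidates = [i for i, w in enumerate(words) if _synonyms(w)]
    for i in random.sample(candidates, min(n, len(candidates))):
        new_words[i] = random.choice(_synonyms(words[i]))
    return new_words

def random_insertion(words, n=1):
    """Insert a synonym of a random word at a random position, n times."""
    new_words = list(words)
    for _ in range(n):
        candidates = [w for w in new_words if _synonyms(w)]
        if not candidates:
            break
        synonym = random.choice(_synonyms(random.choice(candidates)))
        new_words.insert(random.randint(0, len(new_words)), synonym)
    return new_words

def random_deletion(words, p=0.1):
    """Delete each word with probability p, keeping at least one word."""
    kept = [w for w in words if random.random() > p]
    return kept or words[:1]

def back_translation(text, src='en', pivot='de'):
    """Translate to the pivot language and back; requires a translation model or API."""
    raise NotImplementedError("plug in a machine translation model here")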

Putting It All Together: Training Recipe

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler

def train_model(model, train_loader, val_loader, config):
    device = config['device']
    model = model.to(device)

    # Weight init
    model.apply(init_weights)

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay'],
    )

    # Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['epochs']
    )

    # Mixed precision
    scaler = GradScaler('cuda')

    # Early stopping
    early_stop = EarlyStopping(patience=config['patience'])

    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0

    for epoch in range(config['epochs']):
        # ── Train ────────────────────────────────────────────────
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Mixup
            if config.get('mixup', False):
                inputs, targets_a, targets_b, lam = mixup_data(inputs, targets)

            optimizer.zero_grad()
            with autocast('cuda'):
                outputs = model(inputs)
                if config.get('mixup', False):
                    loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
                else:
                    loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

        # ── Validate ─────────────────────────────────────────────
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        val_acc = 100.0 * correct / total
        avg_val_loss = val_loss / len(val_loader)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        early_stop(avg_val_loss)
        if early_stop.should_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    return best_acc

"What I cannot create, I do not understand." — Richard Feynman