Autoencoders
Autoencoders learn compressed representations of data by training a network to reconstruct its input through a bottleneck. This page builds from vanilla autoencoders through denoising autoencoders to variational autoencoders (VAEs) with a full ELBO derivation, implements a VAE from scratch in PyTorch, generates MNIST digits, and applies autoencoders to anomaly detection.
Vanilla Autoencoder
Architecture
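An autoencoder has two parts:
- Encoder $f_\phi: \mathbb{R}^d \to \mathbb{R}^k$ maps an input $x$ to a latent code $z = f_\phi(x)$, with $k < d$ forming the bottleneck
- Decoder $g_\theta: \mathbb{R}^k \to \mathbb{R}^d$ maps the code back to a reconstruction $\hat{x} = g_\theta(z)$
Loss: Minimize the reconstruction error over the $N$ training examples:
$$\mathcal{L}(\phi, \theta) = \frac{1}{N} \sum_{i=1}^{N} \left\| x_i - g_\theta(f_\phi(x_i)) \right\|^2$$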
Implementation
```python
import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),  # Output in [0, 1] for normalized images
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def encode(self, x):
        return self.encoder(x)
```
Limitations of Vanilla Autoencoders
- Irregular latent space: The latent space has no imposed structure, so decoding a point between two encoded digits does not necessarily produce a meaningful digit.
- No generation: You cannot sample from the latent space because you do not know its distribution.
- Overfitting: The network can memorize training data rather than learning useful features.
Denoising Autoencoder
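Denoising autoencoders (Vincent et al., 2008) corrupt the input and train the network to reconstruct the clean version:
$$\mathcal{L}(\phi, \theta) = \left\| x - g_\theta(f_\phi(\tilde{x})) \right\|^2$$
where $\tilde{x}$ is a corrupted copy of $x$, for example $\tilde{x} = \operatorname{clip}(x + \epsilon, 0, 1)$ with $\epsilon \sim \mathcal{N}(0, \sigma_n^2 I)$.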
This forces the encoder to learn robust features rather than memorizing noise patterns.
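A single training step of the denoising objective can be sketched as follows. This is a minimal sketch under stated assumptions: random tensors stand in for MNIST images, the one-hidden-layer model and its sizes are hypothetical, the corruption is clipped Gaussian noise with `noise_factor = 0.3`, and the reconstruction loss is mean squared error. The point it shows is that the network sees the corrupted input while the loss is measured against the clean one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in data: a batch of 16 flattened 28x28 images in [0, 1].
x_clean = torch.rand(16, 784)
noise_factor = 0.3

# Small fully connected autoencoder (illustrative sizes, not tuned).
model = nn.Sequential(
    nn.Linear(784, 32), nn.ReLU(),
    nn.Linear(32, 784), nn.Sigmoid(),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Corrupt the input with Gaussian noise, keeping pixels in [0, 1].
x_noisy = torch.clamp(x_clean + noise_factor * torch.randn_like(x_clean), 0.0, 1.0)

# Key point: the model sees x_noisy, but the loss compares against x_clean.
optimizer.zero_grad()
x_recon = model(x_noisy)
loss = F.mse_loss(x_recon, x_clean)
loss.backward()
optimizer.step()
print(loss.item())
```

Computing the loss against `x_noisy` instead would simply teach the network to copy the noise, which is exactly what the denoising objective is meant to prevent.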
```python
class DenoisingAutoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32, noise_factor=0.3):
        super().__init__()
        self.noise_factor = noise_factor
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),
        )

    def add_noise(self, x):
        noise = torch.randn_like(x) * self.noise_factor
        return torch.clamp(x + noise, 0.0, 1.0)

    def forward(self, x):
        if self.training:
            x_noisy = self.add_noise(x)
        else:
            x_noisy = x
        z = self.encoder(x_noisy)
        return self.decoder(z)
```
Variational Autoencoder (VAE)
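VAEs (Kingma and Welling, 2014) address the first two limitations of vanilla autoencoders, the irregular latent space and the inability to generate, by imposing a probability distribution on the latent space: the encoder outputs a distribution rather than a single point, and a prior keeps latent codes close to a known distribution that can be sampled.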
The Generative Model
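We posit a generative process:
- Sample a latent variable from the prior: $z \sim p(z) = \mathcal{N}(0, I)$
- Generate data from the decoder: $x \sim p_\theta(x \mid z)$
The goal is to maximize the marginal likelihood:
$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$$
This integral is intractable: we would need to integrate over all possible $z$, and $p_\theta(x \mid z)$ is a neural network with no closed-form integral. As a result, the true posterior $p_\theta(z \mid x) = p_\theta(x \mid z)\, p(z) / p_\theta(x)$ is intractable as well.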
The ELBO Derivation
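We introduce an approximate posterior $q_\phi(z \mid x)$, implemented by the encoder, to stand in for the intractable true posterior $p_\theta(z \mid x)$.
Start with the log marginal likelihood:
$$\log p_\theta(x) = \log \int p_\theta(x \mid z)\, p(z)\, dz$$
Multiply and divide by $q_\phi(z \mid x)$ inside the integral:
$$\log p_\theta(x) = \log \int q_\phi(z \mid x)\, \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\, dz = \log \mathbb{E}_{q_\phi(z \mid x)}\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]$$
By Jensen's inequality ($\log$ is concave, so $\log \mathbb{E}[Y] \ge \mathbb{E}[\log Y]$):
$$\log p_\theta(x) \ge \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]$$
Expanding the logarithm:
$$\mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p(z)}\right]$$
and the second expectation is exactly $D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$.
This is the ELBO (Evidence Lower Bound):
$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - D_{KL}(q_\phi(z \mid x) \,\|\, p(z)) \le \log p_\theta(x)$$
The gap between $\log p_\theta(x)$ and the ELBO is exactly the KL divergence from the approximate to the true posterior, $\log p_\theta(x) - \mathcal{L}(\theta, \phi; x) = D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)) \ge 0$, so the bound is tight when $q_\phi$ matches the true posterior. Maximizing the ELBO therefore:
- Rewards accurate reconstructions (first term)
- Pushes the approximate posterior toward the prior (second term)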
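The bound and its gap can be checked numerically on a toy model where every quantity has a closed form. This sketch assumes a 1D linear-Gaussian model chosen purely for illustration, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, for which $p(x) = \mathcal{N}(0, 2)$ and $p(z \mid x) = \mathcal{N}(x/2, 1/2)$; the helper names `log_normal`, `kl_gauss`, and `elbo` are hypothetical.

```python
import math

# Toy model chosen so that everything is tractable (not the MNIST VAE):
#   p(z) = N(0, 1),  p(x | z) = N(z, 1)
#   => p(x) = N(0, 2)  and  p(z | x) = N(x / 2, 1 / 2)


def log_normal(x, mean, var):
    """Log-density of a scalar Gaussian N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def kl_gauss(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1)


def elbo(x, m, v):
    """ELBO for q(z) = N(m, v): E_q[log p(x|z)] - KL(q || p(z)), both analytic."""
    # E_q[(x - z)^2] = (x - m)^2 + v for z ~ N(m, v)
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    return expected_loglik - kl_gauss(m, v, 0.0, 1.0)


x = 1.5
log_px = log_normal(x, 0.0, 2.0)
for m, v in [(0.0, 1.0), (0.5, 0.3), (x / 2, 0.5)]:
    gap = log_px - elbo(x, m, v)
    kl_post = kl_gauss(m, v, x / 2, 0.5)
    print(f"q=N({m:.2f}, {v:.2f}): ELBO={elbo(x, m, v):.4f}, log p(x)={log_px:.4f}, "
          f"gap={gap:.4f}, KL(q || p(z|x))={kl_post:.4f}")
```

For every choice of $q$ the printed gap equals $D_{KL}(q \,\|\, p(z \mid x))$, and it vanishes when $q$ is the true posterior $\mathcal{N}(0.75, 0.5)$.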
KL Divergence: Closed Form
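When $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and $p(z) = \mathcal{N}(0, I)$, the KL term has a closed form, so it needs no sampling.
Derivation: for one dimension with $q = \mathcal{N}(\mu, \sigma^2)$ and $p = \mathcal{N}(0, 1)$ (the $-\tfrac{1}{2}\log 2\pi$ terms cancel),
$$D_{KL}(q \,\|\, p) = \mathbb{E}_q[\log q(z) - \log p(z)] = \mathbb{E}_q\left[-\tfrac{1}{2}\log \sigma^2 - \frac{(z - \mu)^2}{2\sigma^2} + \frac{z^2}{2}\right]$$
Using $\mathbb{E}_q[(z - \mu)^2] = \sigma^2$ and $\mathbb{E}_q[z^2] = \mu^2 + \sigma^2$:
$$D_{KL}(q \,\|\, p) = -\tfrac{1}{2}\log \sigma^2 - \tfrac{1}{2} + \tfrac{1}{2}(\mu^2 + \sigma^2) = -\tfrac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)$$
For diagonal $\Sigma = \operatorname{diag}(\sigma_1^2, \dots, \sigma_J^2)$ the dimensions are independent, so the KL is a sum over the $J$ latent dimensions:
$$D_{KL}(q_\phi(z \mid x) \,\|\, p(z)) = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$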
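The closed form can be sanity-checked by Monte Carlo: draw many samples from a diagonal Gaussian $q$, average $\log q(z) - \log p(z)$, and compare with the formula. The mean and log-variance values below are hypothetical, and the sample count is an arbitrary choice that makes the estimate accurate to roughly two decimal places.

```python
import torch

torch.manual_seed(0)

# Hypothetical diagonal Gaussian q(z|x) = N(mu, diag(sigma^2)) in 3 dimensions.
mu = torch.tensor([0.5, -1.0, 2.0])
logvar = torch.tensor([0.2, -0.7, 0.0])
std = torch.exp(0.5 * logvar)

# Closed form: -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Monte Carlo estimate of E_q[log q(z) - log p(z)] with p(z) = N(0, I).
n = 200_000
z = mu + std * torch.randn(n, 3)
log_q = torch.distributions.Normal(mu, std).log_prob(z).sum(dim=1)
log_p = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(dim=1)
kl_mc = (log_q - log_p).mean()

print(kl_closed.item(), kl_mc.item())  # should agree to about two decimal places
```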
Worked Example — KL Divergence for a 2D Latent Space
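Input: Encoder outputs for one sample: a mean vector $\mu = (\mu_0, \mu_1)$ and a log-variance vector $(\log \sigma_0^2, \log \sigma_1^2)$.
So the variances are $\sigma_j^2 = \exp(\log \sigma_j^2)$.
Step 1: Compute the per-dimension terms $t_j = 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2$:
Dimension 0: $t_0 = 1 + \log \sigma_0^2 - \mu_0^2 - \sigma_0^2$
Dimension 1: $t_1 = 1 + \log \sigma_1^2 - \mu_1^2 - \sigma_1^2$
Step 2: Sum and multiply by $-\tfrac{1}{2}$:
Result: $D_{KL} = -\tfrac{1}{2}(t_0 + t_1)$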
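With concrete numbers filled in, the two steps look like this. The values $\mu = (1.0, -0.5)$ and $\log \sigma^2 = (0.0, -1.0)$ are hypothetical, chosen only for illustration. `torch.distributions.kl.kl_divergence` provides an independent check, since PyTorch registers an analytic KL for pairs of `Normal` distributions.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Hypothetical encoder outputs for one sample with a 2D latent space.
mu = torch.tensor([1.0, -0.5])
logvar = torch.tensor([0.0, -1.0])

# Step 1: per-dimension terms t_j = 1 + log(sigma_j^2) - mu_j^2 - sigma_j^2
terms = 1 + logvar - mu.pow(2) - logvar.exp()

# Step 2: sum and multiply by -1/2
kl_closed_form = -0.5 * terms.sum()

# Cross-check: PyTorch's analytic KL between Normals, summed over dimensions.
std = torch.exp(0.5 * logvar)
kl_reference = kl_divergence(Normal(mu, std), Normal(torch.zeros(2), torch.ones(2))).sum()

print(f"{kl_closed_form.item():.4f} {kl_reference.item():.4f}")  # 0.8089 0.8089
```

Dimension 0 contributes $t_0 = 1 + 0 - 1 - 1 = -1$ and dimension 1 contributes $t_1 = 1 - 1 - 0.25 - e^{-1} \approx -0.6179$, so $D_{KL} \approx 0.8089$.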
The Reparameterization Trick
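We cannot backpropagate through the sampling operation $z \sim q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$, because a random draw is not a differentiable function of $\mu$ and $\sigma$. Instead, express the sample as a deterministic function of the encoder outputs and independent noise:
$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$
Now the randomness is isolated in $\epsilon$, which does not depend on $\phi$, so gradients flow through $\mu$ and $\sigma$ into the encoder. The encoder outputs $\log \sigma^2$ rather than $\sigma$, and $\sigma = \exp(\tfrac{1}{2} \log \sigma^2)$ keeps the standard deviation positive.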
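A minimal sketch of the trick in isolation, assuming hypothetical encoder outputs (a batch of 4 samples with 2 latent dimensions) and using the sum of $z$ as a stand-in loss: after `backward()`, both `mu` and `logvar` receive gradients, which would not happen if $z$ were drawn directly. For reference, PyTorch's `Normal.rsample()` applies the same reparameterization, while `Normal.sample()` does not propagate gradients.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 4 samples, 2 latent dimensions.
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)

# Reparameterized sample: the randomness lives in eps, which has no parameters.
eps = torch.randn(4, 2)
std = torch.exp(0.5 * logvar)
z = mu + eps * std

# Stand-in for a downstream loss; any differentiable function of z works.
z.sum().backward()

print(mu.grad)      # all ones: d(sum z)/d(mu) = 1
print(logvar.grad)  # 0.5 * std * eps, i.e. 0.5 * eps here because logvar = 0
```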
VAE Architecture
From-Scratch VAE
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super().__init__()
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        h = F.relu(self.fc4(h))
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar


def vae_loss(x_recon, x, mu, logvar):
    """VAE loss = Reconstruction + KL divergence."""
    # Reconstruction loss (binary cross-entropy)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss
```
Training the VAE on MNIST
```python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Data
transform = T.Compose([T.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    './data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(784, 256, 20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training
for epoch in range(50):
    model.train()
    total_loss = 0
    for images, _ in train_loader:
        images = images.view(-1, 784).to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(images)
        loss = vae_loss(recon, images, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataset)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.2f}")

# ── Generate new digits ──────────────────────────────────────────────
model.eval()
with torch.no_grad():
    z = torch.randn(64, 20).to(device)  # Sample from prior
    generated = model.decode(z).view(-1, 1, 28, 28)

# Visualize
import matplotlib.pyplot as plt
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('VAE Generated Digits')
plt.tight_layout()
plt.show()

# ── Latent space interpolation ───────────────────────────────────────
with torch.no_grad():
    # Encode two images
    img1 = train_dataset[0][0].view(1, 784).to(device)
    img2 = train_dataset[1][0].view(1, 784).to(device)
    mu1, _ = model.encode(img1)
    mu2, _ = model.encode(img2)
    # Interpolate
    alphas = torch.linspace(0, 1, 10).to(device)
    interpolations = []
    for alpha in alphas:
        z = (1 - alpha) * mu1 + alpha * mu2
        img = model.decode(z).view(1, 28, 28)
        interpolations.append(img.cpu())

# Plot interpolation
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axes):
    ax.imshow(interpolations[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Latent Space Interpolation')
plt.show()
```
Beta-VAE: Disentangled Representations
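Beta-VAE (Higgins et al., 2017) weights the KL term by a factor $\beta$:
$$\mathcal{L}_\beta(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \beta \, D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$$
- $\beta > 1$: Stronger regularization, more disentangled latent space, blurrier reconstructions
- $\beta < 1$: Better reconstructions, less disentangled
- $\beta = 1$: Standard VAE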
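A minimal sketch of the Beta-VAE objective as a variant of `vae_loss`: the only change is the `beta` multiplier on the KL term. The function name `beta_vae_loss`, the default `beta=4.0`, and the random tensors in the toy check are illustrative assumptions, not values prescribed by the Beta-VAE paper.

```python
import torch
import torch.nn.functional as F


def beta_vae_loss(x_recon, x, mu, logvar, beta=4.0):
    """Negative beta-ELBO: reconstruction + beta * KL, summed over the batch."""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss


# Toy check with random tensors (hypothetical shapes: batch 8, 784 pixels, 20 latent dims).
torch.manual_seed(0)
x = torch.rand(8, 784)
x_recon = torch.rand(8, 784).clamp(1e-4, 1 - 1e-4)  # keep BCE away from log(0)
mu, logvar = torch.randn(8, 20), torch.randn(8, 20)

loss_b1 = beta_vae_loss(x_recon, x, mu, logvar, beta=1.0)  # standard VAE loss
loss_b4 = beta_vae_loss(x_recon, x, mu, logvar, beta=4.0)  # stronger KL pressure
print(loss_b1.item(), loss_b4.item())
```

Because the KL term is non-negative, raising $\beta$ can only increase the loss for a fixed encoder output, so training trades reconstruction quality for posteriors that stay closer to the prior.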
Convolutional VAE
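For images, use a convolutional encoder and a transposed-convolution decoder, which exploit local spatial structure that fully connected layers do not build in: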
```python
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            # output_padding=1 makes each transposed conv exactly double the size
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),  # 14x14 -> 28x28
            nn.Sigmoid(),
        )

    def encode(self, x):
        # x: (N, 1, 28, 28) image batch, not flattened
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        h = self.fc_decode(z).view(-1, 64, 7, 7)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```
Anomaly Detection with Autoencoders
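Autoencoders trained only on normal data tend to reconstruct anomalous inputs poorly, so the per-sample reconstruction error $s(x) = \frac{1}{d}\|x - \hat{x}\|^2$ can serve as an anomaly score, with an input flagged when $s(x)$ exceeds a threshold $\tau$.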
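As a toy, self-contained illustration of this principle, a linear autoencoder with a 2D bottleneck is trained only on "normal" points lying on a 2D subspace of $\mathbb{R}^{10}$ and then scored on off-subspace points. The synthetic data, the linear model, the step count, and the learning rate are all assumptions made for the illustration; the threshold uses the same mean plus two standard deviations rule as the `detect_anomalies` helper in this section, but computed on errors from the normal training data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical "normal" data: 512 points on a 2D linear subspace of R^10.
W = torch.randn(2, 10)
normal = torch.randn(512, 2) @ W
# Hypothetical anomalies: 64 isotropic Gaussian points, mostly off that subspace.
anomalies = torch.randn(64, 10)

# Linear autoencoder with a 2D bottleneck (no sigmoid: the data is not in [0, 1]).
model = nn.Sequential(nn.Linear(10, 2), nn.Linear(2, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train on normal data only (full batch).
for step in range(2000):
    optimizer.zero_grad()
    loss = F.mse_loss(model(normal), normal)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    # Per-sample reconstruction error: mean squared error over the 10 features.
    err_normal = F.mse_loss(model(normal), normal, reduction='none').mean(dim=1)
    err_anom = F.mse_loss(model(anomalies), anomalies, reduction='none').mean(dim=1)

# Threshold: mean + 2 * std of the errors on normal data.
threshold = (err_normal.mean() + 2 * err_normal.std()).item()
flagged_normal = (err_normal > threshold).float().mean().item()
flagged_anom = (err_anom > threshold).float().mean().item()
print(f"threshold={threshold:.4f}, normal flagged={flagged_normal:.1%}, "
      f"anomalies flagged={flagged_anom:.1%}")
```

Deriving the threshold from errors on data known to be normal (ideally a held-out validation split) keeps the anomalies themselves from inflating the mean and standard deviation.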
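```python
import numpy as np
import torch
import torch.nn.functional as F


def detect_anomalies(model, data_loader, threshold=None, device='cpu', flatten=True):
    """Detect anomalies based on per-sample reconstruction error.

    Set flatten=False for convolutional models (e.g. ConvVAE) that expect
    (N, 1, 28, 28) input instead of flattened 784-dim vectors.
    """
    model.eval()
    errors = []
    with torch.no_grad():
        for batch, _ in data_loader:
            batch = batch.to(device)
            if flatten:
                batch = batch.view(batch.size(0), -1)
            out = model(batch)
            # VAEs return (recon, mu, logvar); plain autoencoders return recon only
            recon = out[0] if isinstance(out, tuple) else out
            # Mean squared error over all pixels of each sample
            error = F.mse_loss(recon, batch, reduction='none').flatten(1).mean(dim=1)
            errors.extend(error.cpu().numpy())
    errors = np.array(errors)
    if threshold is None:
        # Use mean + 2 * std as threshold (computed on the scored data itself;
        # a threshold from normal validation data is more robust)
        threshold = errors.mean() + 2 * errors.std()
    anomalies = errors > threshold
    print(f"Threshold: {threshold:.4f}")
    print(f"Anomalies detected: {anomalies.sum()} / {len(errors)}")
    return anomalies, errors
```
VAE vs GAN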
| Feature | VAE | GAN |
|---|---|---|
| Training | Stable (single loss) | Unstable (adversarial) |
| Output quality | Often blurry (pixel-wise MSE/BCE loss) | Sharp |
| Latent space | Smooth, interpretable | Less structured |
| Likelihood | Tractable lower bound (ELBO) | No likelihood estimate |
| Mode coverage | Good (tends to cover all modes) | Mode collapse risk |
| Best for | Representation learning, anomaly detection | High-quality generation |
Cross-References
- Competing approach: GANs --- adversarial generation
- Modern generation: Diffusion Models --- state-of-the-art image generation
- Foundations: Neural Network Basics --- backprop, loss functions
- Training: Training Techniques --- regularization, scheduling
- PyTorch: PyTorch Fundamentals --- tensors, modules