Unverified: AI-generated content.

Autoencoders

Autoencoders learn compressed representations of data by training a network to reconstruct its input through a bottleneck. This page builds from vanilla autoencoders through denoising autoencoders to variational autoencoders (VAEs) with full ELBO derivation, implements a VAE from scratch in PyTorch, generates MNIST digits, and applies autoencoders to anomaly detection.

Vanilla Autoencoder

Architecture

An autoencoder has two parts:

Encoder $f_\phi$: maps the input $x \in \mathbb{R}^D$ to a latent representation $z$:

$$z = f_\phi(x), \quad z \in \mathbb{R}^d, \quad d \ll D$$

Decoder $g_\theta$: reconstructs the input from the latent:

$$\hat{x} = g_\theta(z)$$

Loss: Minimize reconstruction error:

$$\mathcal{L}_{AE} = \|x - \hat{x}\|^2 = \|x - g_\theta(f_\phi(x))\|^2$$

Implementation

python
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),  # Output in [0, 1] for normalized images
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

    def encode(self, x):
        return self.encoder(x)
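
The module above defines the network but not how it is trained. A minimal sketch of the training loop, assuming hypothetical synthetic data in $[0, 1]$ (784-dimensional vectors that lie on an 8-dimensional manifold) as a stand-in for flattened MNIST, and a smaller copy of the same encoder/bottleneck/decoder pattern so the block runs on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in data: 256 samples of dimension 784 generated from
# 8 underlying factors, squashed into [0, 1] like pixel intensities.
codes = torch.randn(256, 8)
mixing = torch.randn(8, 784)
x = torch.sigmoid(codes @ mixing)

# Same encoder -> 32-dim bottleneck -> decoder pattern as Autoencoder, but smaller
encoder = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 32))
decoder = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 784), nn.Sigmoid())
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# MSELoss averages squared errors: sum of ||x - x_hat||^2 over the batch / (256 * 784)
loss_fn = nn.MSELoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    x_hat = decoder(encoder(x))
    loss = loss_fn(x_hat, x)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

With real data the loop is the same, except that `x` comes from a `DataLoader` one batch at a time.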

Limitations of Vanilla Autoencoders

  1. Irregular latent space: The latent space has no structure. Points between two encoded digits do not necessarily produce meaningful digits.
  2. No generation: You cannot sample from the latent space because you do not know its distribution.
  3. Overfitting: The network can memorize training data rather than learning useful features.

Denoising Autoencoder

Denoising autoencoders (Vincent et al., 2008) corrupt the input and train the network to reconstruct the clean version:

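$$\tilde{x} = x + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2 I)$$

$$\mathcal{L}_{DAE} = \|x - g_\theta(f_\phi(\tilde{x}))\|^2$$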

Because the loss compares the reconstruction against the clean input, the network cannot succeed by simply copying its input; it has to learn features that are robust to the corruption.

python
class DenoisingAutoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32, noise_factor=0.3):
        super().__init__()
        self.noise_factor = noise_factor
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),
        )

    def add_noise(self, x):
        noise = torch.randn_like(x) * self.noise_factor
        return torch.clamp(x + noise, 0.0, 1.0)

    def forward(self, x):
        if self.training:
            x_noisy = self.add_noise(x)
        else:
            x_noisy = x
        z = self.encoder(x_noisy)
        return self.decoder(z)
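
What makes this a denoising autoencoder is the training step: the corrupted input goes into the network, but the loss is computed against the clean input. A minimal sketch, assuming a random batch in $[0, 1]$ and a tiny hypothetical model; `corrupt` is a hypothetical helper that mirrors `add_noise` above (Gaussian noise, then clamping to the valid pixel range):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def corrupt(x, noise_factor=0.3):
    """Additive Gaussian noise, clamped back to [0, 1] (mirrors add_noise)."""
    return torch.clamp(x + noise_factor * torch.randn_like(x), 0.0, 1.0)

model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 32),                 # encoder
    nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 784), nn.Sigmoid(),  # decoder
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x_clean = torch.rand(64, 784)      # hypothetical batch of clean images
x_noisy = corrupt(x_clean)         # what the network sees

optimizer.zero_grad()
x_hat = model(x_noisy)
loss = F.mse_loss(x_hat, x_clean)  # target is the CLEAN input, not x_noisy
loss.backward()
optimizer.step()
```

The clamp makes the corruption slightly non-Gaussian near 0 and 1; the formula above describes the unclamped noise.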

Variational Autoencoder (VAE)

VAEs (Kingma and Welling, 2014) address the first two limitations of vanilla autoencoders (the irregular latent space and the inability to generate new samples) by imposing a probability distribution on the latent space.

The Generative Model

We posit a generative process:

  1. Sample a latent variable: $z \sim p(z) = \mathcal{N}(0, I)$
  2. Generate data: $x \sim p_\theta(x \mid z)$

The goal is to maximize the marginal likelihood:

$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$$

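This integral is intractable: with a neural-network decoder it has no closed form, and estimating it by sampling $z$ from the prior is impractical because, in high dimensions, almost all prior samples assign negligible likelihood to a given $x$.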
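
To make the integral concrete, here is a toy sketch that is not part of the VAE itself: a 1-D latent with $p(z) = \mathcal{N}(0, 1)$ and an assumed linear-Gaussian decoder $p_\theta(x \mid z) = \mathcal{N}(x; z, 1)$, for which the exact marginal $p_\theta(x) = \mathcal{N}(x; 0, 2)$ is known. Averaging $p_\theta(x \mid z)$ over prior samples works in this 1-D case; the same estimator breaks down for the high-dimensional latents and sharply peaked neural decoders of a real VAE.

```python
import numpy as np

rng = np.random.default_rng(0)

def normal_pdf(x, mean, var):
    """Density of N(mean, var) evaluated at x."""
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

x = 1.0
z = rng.standard_normal(200_000)             # z_s ~ p(z) = N(0, 1)
mc_estimate = normal_pdf(x, z, 1.0).mean()   # (1/S) * sum_s p(x | z_s)
exact = normal_pdf(x, 0.0, 2.0)              # marginal is N(0, 1 + 1)

print(f"Monte Carlo: {mc_estimate:.4f}, exact: {exact:.4f}")
```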

The ELBO Derivation

We introduce an approximate posterior $q_\phi(z \mid x)$ (the encoder) and derive the Evidence Lower Bound.

Start with the log marginal likelihood:

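$$\log p_\theta(x) = \log \int p_\theta(x \mid z)\, p(z)\, dz$$

Multiply and divide by $q_\phi(z \mid x)$:

$$\log p_\theta(x) = \log \int \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\, q_\phi(z \mid x)\, dz$$

By Jensen's inequality (log is concave, so $\log \mathbb{E}_q[f(z)] \geq \mathbb{E}_q[\log f(z)]$):

$$\log p_\theta(x) \geq \int q_\phi(z \mid x) \log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\, dz$$

Splitting the logarithm into two integrals:

$$\log p_\theta(x) \geq \int q_\phi(z \mid x) \log p_\theta(x \mid z)\, dz + \int q_\phi(z \mid x) \log \frac{p(z)}{q_\phi(z \mid x)}\, dz$$

The second integral is $-D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$, so:

$$\log p_\theta(x) \geq \underbrace{\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]}_{\text{Reconstruction}} - \underbrace{D_{KL}(q_\phi(z \mid x) \,\|\, p(z))}_{\text{Regularization}}$$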

This is the ELBO (Evidence Lower Bound):

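$$\mathcal{L}_{ELBO} = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$$

The gap between $\log p_\theta(x)$ and the ELBO is exactly $D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x))$, which is non-negative. Maximizing the ELBO simultaneously: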

  • Maximizes the reconstruction quality (first term)
  • Pushes the approximate posterior toward the prior (second term)
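
The bound and its gap can be checked exactly in a toy model. A sketch, assuming $p(z) = \mathcal{N}(0, 1)$, a linear-Gaussian decoder $p_\theta(x \mid z) = \mathcal{N}(x; z, 1)$, and a Gaussian $q(z \mid x) = \mathcal{N}(m, s^2)$ with hypothetical parameters `m` and `s2`. In this model the true posterior is $\mathcal{N}(x/2, 1/2)$ and every term has a closed form, so the code can confirm that $\log p_\theta(x) - \mathcal{L}_{ELBO}$ equals the KL divergence to the true posterior:

```python
import math

def gauss_kl(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians (v = variance)."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x, m, s2):
    """ELBO for p(z)=N(0,1), p(x|z)=N(z,1), q(z|x)=N(m, s2)."""
    # E_q[log N(x; z, 1)] = -0.5*log(2*pi) - 0.5*E_q[(x - z)^2]
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    return expected_loglik - gauss_kl(m, s2, 0.0, 1.0)

def log_marginal(x):
    """Exact log p(x), with p(x) = N(x; 0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0

x = 1.0
for m, s2 in [(0.0, 1.0), (0.8, 0.2), (0.5, 0.5)]:  # the last q is the true posterior
    gap = log_marginal(x) - elbo(x, m, s2)
    kl_post = gauss_kl(m, s2, x / 2, 0.5)
    print(f"q=N({m}, {s2}): gap={gap:.6f}, KL(q || posterior)={kl_post:.6f}")
```

When $q$ equals the true posterior the gap is zero, which is why improving the encoder tightens the bound.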

KL Divergence: Closed Form

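When $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and $p(z) = \mathcal{N}(0, I)$:

$$D_{KL}(q \,\|\, p) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

Derivation:

$$D_{KL} = \int q(z) \log \frac{q(z)}{p(z)}\, dz = \frac{1}{2}\left[\operatorname{tr}(\Sigma) + \mu^T \mu - d - \log \det(\Sigma)\right]$$

For diagonal $\Sigma = \operatorname{diag}(\sigma_1^2, \ldots, \sigma_d^2)$:

$$= \frac{1}{2}\left[\sum_j \sigma_j^2 + \sum_j \mu_j^2 - d - \sum_j \log \sigma_j^2\right] = -\frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$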
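
A quick numerical check of the closed form: a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$ should approach $-\frac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)$. A sketch with arbitrary, hypothetical values of $\mu$ and $\log\sigma^2$ for a 3-dimensional latent:

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([0.3, -1.2, 0.8])
logvar = np.array([-0.4, 0.6, 0.1])
sigma = np.exp(0.5 * logvar)

# Closed form for KL( N(mu, diag(sigma^2)) || N(0, I) )
kl_closed = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))

def log_normal_diag(z, mean, var):
    """Log density of a diagonal Gaussian, summed over the last axis."""
    return np.sum(-0.5 * np.log(2 * np.pi * var) - (z - mean) ** 2 / (2 * var), axis=-1)

# Monte Carlo: draw z ~ q, then average log q(z) - log p(z)
z = mu + sigma * rng.standard_normal((500_000, 3))
kl_mc = np.mean(log_normal_diag(z, mu, sigma**2) - log_normal_diag(z, 0.0, 1.0))

print(f"closed form: {kl_closed:.4f}, Monte Carlo: {kl_mc:.4f}")
```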
Worked Example — KL Divergence for a 2D Latent Space

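Input: Encoder outputs for one sample: $\mu = [0.5, 1.0]$, $\log \sigma^2 = [-0.5, 0.3]$

So $\sigma^2 = [e^{-0.5}, e^{0.3}] = [0.607, 1.350]$

Step 1: Compute the per-dimension terms $1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2$:

Dimension 0: $1 + (-0.5) - (0.5)^2 - 0.607 = 1 - 0.5 - 0.25 - 0.607 = -0.357$

Dimension 1: $1 + 0.3 - (1.0)^2 - 1.350 = 1 + 0.3 - 1.0 - 1.350 = -1.050$

Step 2: Sum and multiply by $-\frac{1}{2}$:

$$D_{KL} = -\frac{1}{2}\left(-0.357 + (-1.050)\right) = -\frac{1}{2}(-1.407) \approx 0.704$$

Result: $D_{KL} \approx 0.704$ (about 0.703 without intermediate rounding). Dimension 1 contributes more ($1.050/2 = 0.525$, versus $0.357/2 \approx 0.179$ for dimension 0) because $\mu_1 = 1.0$ is far from the prior mean of 0 and $\sigma_1^2 = 1.35$ differs from the prior variance of 1. If $\mu = [0, 0]$ and $\sigma^2 = [1, 1]$ (matching the prior exactly), $D_{KL} = 0$.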
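
The worked example can be reproduced in PyTorch with the same expression `vae_loss` uses, and cross-checked against `torch.distributions.kl_divergence` for two `Normal` distributions. A sketch; the totals should match the hand calculation up to rounding:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.5, 1.0])
logvar = torch.tensor([-0.5, 0.3])

# Per-dimension closed-form terms, the same expression as in vae_loss
per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
kl_manual = per_dim.sum()

# Cross-check: KL between diagonal Gaussians, computed per dimension, then summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros(2), torch.ones(2))
kl_library = kl_divergence(q, p).sum()

print(per_dim)  # contribution of each latent dimension
print(kl_manual.item(), kl_library.item())
```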

The Reparameterization Trick

We cannot backpropagate through the sampling operation $z \sim q_\phi(z \mid x)$. The reparameterization trick makes sampling differentiable:

$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

Now z is a deterministic function of μ, σ, and the random noise ϵ. Gradients flow through μ and σ to the encoder.
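
A small check that gradients really reach $\mu$ and $\log\sigma^2$ through a reparameterized sample. PyTorch's `torch.distributions.Normal` exposes the same distinction: `rsample()` is reparameterized and differentiable, while `sample()` is not. A sketch with arbitrary values:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
logvar = torch.tensor([0.0, -0.5], requires_grad=True)

# Manual reparameterization, as in VAE.reparameterize
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)          # noise carries no gradient
z = mu + eps * std
z.sum().backward()
print(mu.grad)      # d(sum z)/d mu = 1 in each dimension
print(logvar.grad)  # d(sum z)/d logvar = 0.5 * eps * std

# Library equivalent: rsample() keeps the graph, sample() does not
dist = Normal(mu, torch.exp(0.5 * logvar))
print(dist.rsample().requires_grad, dist.sample().requires_grad)
```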

VAE Architecture

From-Scratch VAE

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super().__init__()
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        h = F.relu(self.fc4(h))
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

def vae_loss(x_recon, x, mu, logvar):
    """VAE loss = Reconstruction + KL divergence."""
    # Reconstruction loss (binary cross-entropy)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_loss
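
One hedged variant, not used elsewhere on this page: `F.binary_cross_entropy` on sigmoid outputs clamps its log terms at -100, so saturated decoder outputs are scored only approximately. PyTorch's `F.binary_cross_entropy_with_logits` fuses the sigmoid into the loss and is documented as more numerically stable. The sketch assumes a decoder whose last layer returns raw logits (i.e. `decode` without the final `torch.sigmoid`); `vae_loss_from_logits` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def vae_loss_from_logits(logits, x, mu, logvar):
    """Same objective as vae_loss, but the decoder returns logits, not probabilities."""
    # Sigmoid + BCE fused into one numerically stable op
    recon_loss = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

# Random tensors standing in for one batch (batch 8, 784 pixels, 20 latents)
torch.manual_seed(0)
logits = torch.randn(8, 784)
x = torch.rand(8, 784)  # targets in [0, 1], like ToTensor() pixels
mu, logvar = torch.randn(8, 20), torch.randn(8, 20)

loss = vae_loss_from_logits(logits, x, mu, logvar)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
reference = F.binary_cross_entropy(torch.sigmoid(logits), x, reduction='sum') + kl
print(loss.item(), reference.item())  # agree for moderate logits

# Saturated case: sigmoid(50.0) rounds to exactly 1.0 in float32, so the
# probability-based BCE hits its -100 clamp on log(1 - p)
big, zero = torch.tensor([50.0]), torch.tensor([0.0])
stable = F.binary_cross_entropy_with_logits(big, zero).item()      # about 50
clamped = F.binary_cross_entropy(torch.sigmoid(big), zero).item()  # 100 after clamping
print(stable, clamped)
```

To use it, drop the `torch.sigmoid` in `decode` during training and apply `torch.sigmoid` only when you need images, for example when visualizing generated digits.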

Training the VAE on MNIST

python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Data
transform = T.Compose([T.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    './data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(784, 256, 20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training
for epoch in range(50):
    model.train()
    total_loss = 0
    for images, _ in train_loader:
        images = images.view(-1, 784).to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(images)
        loss = vae_loss(recon, images, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataset)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.2f}")

# ── Generate new digits ──────────────────────────────────────────────
model.eval()
with torch.no_grad():
    z = torch.randn(64, 20).to(device)  # Sample from prior
    generated = model.decode(z).view(-1, 1, 28, 28)

    # Visualize
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(8, 8, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('VAE Generated Digits')
    plt.tight_layout()
    plt.show()

# ── Latent space interpolation ───────────────────────────────────────
with torch.no_grad():
    # Encode two images
    img1 = train_dataset[0][0].view(1, 784).to(device)
    img2 = train_dataset[1][0].view(1, 784).to(device)
    mu1, _ = model.encode(img1)
    mu2, _ = model.encode(img2)

    # Interpolate
    alphas = torch.linspace(0, 1, 10).to(device)
    interpolations = []
    for alpha in alphas:
        z = (1 - alpha) * mu1 + alpha * mu2
        img = model.decode(z).view(1, 28, 28)
        interpolations.append(img.cpu())

    # Plot interpolation
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i, ax in enumerate(axes):
        ax.imshow(interpolations[i].squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('Latent Space Interpolation')
    plt.show()
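
The interpolation loop decodes one latent at a time. A hedged alternative builds all interpolated latents as one batch with broadcasting, so a single `decode` call can handle them; `interpolate_latents` is a hypothetical helper, and random vectors stand in for `mu1` and `mu2`:

```python
import torch

def interpolate_latents(z1, z2, steps=10):
    """Linear interpolation between latents z1, z2 of shape (1, d); returns (steps, d)."""
    alphas = torch.linspace(0, 1, steps, device=z1.device).unsqueeze(1)  # (steps, 1)
    return (1 - alphas) * z1 + alphas * z2  # broadcasts to (steps, d)

torch.manual_seed(0)
mu1, mu2 = torch.randn(1, 20), torch.randn(1, 20)
z_path = interpolate_latents(mu1, mu2, steps=10)
print(z_path.shape)
# With the trained model: images = model.decode(z_path).view(-1, 1, 28, 28)
```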

Beta-VAE: Disentangled Representations

Beta-VAE (Higgins et al., 2017) weights the KL term:

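$$\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \beta\, D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$$
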
  • β>1: Stronger regularization, more disentangled latent space, blurrier reconstructions
  • β<1: Better reconstructions, less disentangled
  • β=1: Standard VAE
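
Since training minimizes the negative of $\mathcal{L}_{\beta\text{-VAE}}$, the only change relative to `vae_loss` is a scalar weight on the KL term. A sketch, where `beta_vae_loss` is a hypothetical name and `beta=4.0` is only an illustrative default:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x_recon, x, mu, logvar, beta=4.0):
    """Negative beta-VAE objective: reconstruction BCE + beta * KL(q(z|x) || N(0, I))."""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

# Random tensors standing in for one batch
torch.manual_seed(0)
x = torch.rand(4, 784)
x_recon = torch.sigmoid(torch.randn(4, 784))
mu, logvar = torch.randn(4, 10), torch.randn(4, 10)
for beta in (0.5, 1.0, 4.0):
    print(beta, beta_vae_loss(x_recon, x, mu, logvar, beta).item())
```

With `beta=1.0` this reduces to `vae_loss`.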

Convolutional VAE

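For images, a convolutional encoder and decoder usually fit better than fully connected layers, since they exploit the spatial structure of the pixels: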

python
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        h = self.fc_decode(z).view(-1, 64, 7, 7)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
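
The shape comments in the encoder and decoder follow from the output-size formulas in the PyTorch `Conv2d` and `ConvTranspose2d` documentation. A small sketch that checks them; the helper names are hypothetical:

```python
def conv2d_out(h, kernel, stride=1, padding=0, dilation=1):
    """Output size of nn.Conv2d along one spatial dimension."""
    return (h + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv_transpose2d_out(h, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output size of nn.ConvTranspose2d along one spatial dimension."""
    return (h - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Encoder: 28 -> 14 -> 7
h = 28
for _ in range(2):
    h = conv2d_out(h, kernel=3, stride=2, padding=1)
    print(h)

# Decoder: 7 -> 14 -> 28. With stride 2, both 13 and 14 map to 7 in the encoder,
# so output_padding=1 selects 14 (without it, 7 would map back to 13).
for _ in range(2):
    h = conv_transpose2d_out(h, kernel=3, stride=2, padding=1, output_padding=1)
    print(h)
```

The flattened encoder output is therefore $64 \times 7 \times 7 = 3136$ features, matching `fc_mu` and `fc_logvar`.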

Anomaly Detection with Autoencoders

Autoencoders trained only on normal data tend to reconstruct normal inputs well and anomalous inputs poorly, so a high reconstruction error can flag an anomaly:

python
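import numpy as np
import torch
import torch.nn.functional as F

def detect_anomalies(model, data_loader, threshold=None, device='cpu'):
    """Detect anomalies based on per-sample reconstruction error.

    Assumes a fully connected model (Autoencoder or VAE) on flattened 784-dim inputs.
    """
    model.eval()
    errors = []
    with torch.no_grad():
        for batch, _ in data_loader:
            batch = batch.view(-1, 784).to(device)
            if isinstance(model, VAE):
                # Decode the posterior mean rather than a random sample,
                # so the anomaly score is deterministic
                mu, _ = model.encode(batch)
                recon = model.decode(mu)
            else:
                recon = model(batch)
            # Mean squared error per sample, averaged over the 784 pixels
            error = F.mse_loss(recon, batch, reduction='none').mean(dim=1)
            errors.extend(error.cpu().numpy())

    errors = np.array(errors)

    if threshold is None:
        # Default: mean + 2 * std of the errors being scored
        threshold = errors.mean() + 2 * errors.std()

    anomalies = errors > threshold
    print(f"Threshold: {threshold:.4f}")
    print(f"Anomalies detected: {anomalies.sum()} / {len(errors)}")
    return anomalies, errors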
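
`detect_anomalies` computes its default threshold from the same errors it scores, so any anomalies present inflate the mean and standard deviation. A hedged alternative, assuming a held-out set of known-normal samples is available, fits the threshold on their errors and then applies it to new data. Synthetic error values stand in for real reconstruction errors here, and the helper names are hypothetical:

```python
import numpy as np

def fit_threshold(normal_errors, k=2.0):
    """Threshold = mean + k * std of reconstruction errors on held-out NORMAL data."""
    normal_errors = np.asarray(normal_errors)
    return normal_errors.mean() + k * normal_errors.std()

def flag_anomalies(errors, threshold):
    """Boolean mask: True where the error exceeds the threshold."""
    return np.asarray(errors) > threshold

rng = np.random.default_rng(0)
val_errors = rng.normal(0.02, 0.005, size=1000)  # hypothetical: normal data only
test_errors = np.concatenate([
    rng.normal(0.02, 0.005, size=95),            # hypothetical normal test samples
    rng.normal(0.08, 0.01, size=5),              # hypothetical anomalies
])

threshold = fit_threshold(val_errors)
flags = flag_anomalies(test_errors, threshold)
print(f"threshold={threshold:.4f}, flagged {flags.sum()} of {len(flags)}")
```

With $k = 2$, a few percent of normal samples will still exceed the threshold; raising `k` trades missed anomalies for fewer false alarms.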

VAE vs GAN

| Feature | VAE | GAN |
|---|---|---|
| Training | Stable (single loss) | Unstable (adversarial) |
| Output quality | Often blurry (MSE/BCE loss) | Sharp |
| Latent space | Smooth, interpretable | Less structured |
| Likelihood | Tractable lower bound (ELBO) | No likelihood estimate |
| Mode coverage | Good (tends to cover all modes) | Mode collapse risk |
| Best for | Representation learning, anomaly detection | High-quality generation |


"What I cannot create, I do not understand." — Richard Feynman