Unverified: AI-generated content.

Generative Adversarial Networks

GANs (Goodfellow et al., 2014) train two networks against each other: a generator that creates fake data and a discriminator that tries to tell real from fake. This adversarial training can produce highly realistic images. This page derives the minimax objective, analyzes mode collapse, introduces WGAN with gradient penalty, implements conditional GANs, builds a GAN from scratch for MNIST, and provides practical training guidance.

The Minimax Game

Setup

  • Generator $G(z)$: takes random noise $z \sim p_z(z)$ and outputs fake data $G(z)$
  • Discriminator D(x): takes data (real or fake) and outputs the probability that it is real

The Objective

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Discriminator's goal (maximize V):

  • $D(x) \to 1$ for real data (maximize $\log D(x)$)
  • $D(G(z)) \to 0$ for fake data (maximize $\log(1 - D(G(z)))$)

Generator's goal (minimize V):

  • $D(G(z)) \to 1$ for fake data (minimize $\log(1 - D(G(z)))$, i.e., fool the discriminator)

Optimal Discriminator

For a fixed G, the optimal discriminator is:

$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$$

Derivation: The discriminator maximizes:

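$$V = \int_x \left[ p_{\text{data}}(x) \log D(x) + p_g(x) \log(1 - D(x)) \right] dx$$

Taking the derivative of the integrand with respect to $D(x)$ (pointwise, for each $x$) and setting it to zero:

$$\frac{p_{\text{data}}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0$$

Solving: $p_{\text{data}}(x)\,(1 - D(x)) = p_g(x)\,D(x)$, so $D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$. The second derivative, $-\frac{p_{\text{data}}(x)}{D(x)^2} - \frac{p_g(x)}{(1 - D(x))^2}$, is negative, so this stationary point is a maximum.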
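Because the maximization is pointwise, the closed form can be sanity-checked at a single $x$ by a brute-force grid search over $D(x)$. A minimal NumPy sketch; the density values `p_data_x` and `p_g_x` are hypothetical numbers chosen for illustration:

```python
import numpy as np

def pointwise_objective(d, p_data_x, p_g_x):
    """Integrand of V at one point x, as a function of the discriminator output d = D(x)."""
    return p_data_x * np.log(d) + p_g_x * np.log(1 - d)

# Hypothetical density values at a single point x
p_data_x, p_g_x = 0.75, 0.25
d_grid = np.linspace(1e-4, 1 - 1e-4, 100_001)
d_numeric = d_grid[np.argmax(pointwise_objective(d_grid, p_data_x, p_g_x))]
d_closed_form = p_data_x / (p_data_x + p_g_x)  # D*(x) from the derivation
print(f"grid search: {d_numeric:.4f}, closed form: {d_closed_form:.4f}")
```

The grid maximum lands on $p_{\text{data}}(x) / (p_{\text{data}}(x) + p_g(x)) = 0.75$ up to the grid spacing.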

Global Optimum

Substituting $D^*$ back into $V$:

$$V(G, D^*) = -\log 4 + 2\, D_{JS}(p_{\text{data}} \,\|\, p_g)$$

where $D_{JS}$ is the Jensen-Shannon divergence. Since $D_{JS} \ge 0$, the global minimum is $-\log 4$, achieved when $p_g = p_{\text{data}}$ (the generator perfectly matches the data distribution).
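The identity can be checked numerically on discrete distributions, where the integral becomes a sum. A small NumPy sketch; the two probability vectors are arbitrary illustrative values, and the helper names are hypothetical:

```python
import numpy as np

def kl(p, q):
    """KL divergence for discrete distributions (terms with p = 0 contribute 0)."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def js(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_d(p_data, p_g):
    """V(G, D*) on a discrete support, with D* = p_data / (p_data + p_g)."""
    d_star = p_data / (p_data + p_g)
    real_term = np.sum(p_data[p_data > 0] * np.log(d_star[p_data > 0]))
    fake_term = np.sum(p_g[p_g > 0] * np.log(1 - d_star[p_g > 0]))
    return real_term + fake_term

p_data = np.array([0.5, 0.3, 0.2, 0.0])
p_g = np.array([0.1, 0.2, 0.3, 0.4])
lhs = value_at_optimal_d(p_data, p_g)
rhs = -np.log(4) + 2 * js(p_data, p_g)
print(np.isclose(lhs, rhs))                                # True
print(np.isclose(value_at_optimal_d(p_g, p_g), -np.log(4)))  # True: the minimum at p_g = p_data
```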

Training Algorithm

for each training iteration:
    # 1. Train Discriminator (k steps, usually k=1)
    Sample mini-batch {x_1, ..., x_m} from data
    Sample mini-batch {z_1, ..., z_m} from noise prior
    Update D by ascending:
        ∇_D [1/m Σ log D(x_i) + 1/m Σ log(1 - D(G(z_i)))]

    # 2. Train Generator (1 step)
    Sample mini-batch {z_1, ..., z_m} from noise prior
    Update G by descending:
        ∇_G [1/m Σ log(1 - D(G(z_i)))]

Non-Saturating Generator Loss

In practice, $\log(1 - D(G(z)))$ provides very small gradients when $D(G(z)) \approx 0$ (early in training, when the discriminator easily wins). Instead, the generator maximizes $\log D(G(z))$, i.e., minimizes:

$$L_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]$$

This has the same fixed point but stronger gradients early in training.
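The difference shows up in the gradient with respect to the discriminator's logit $a$ on a fake sample, where $D(G(z)) = \sigma(a)$: $\frac{d}{da}\log(1 - \sigma(a)) = -\sigma(a)$, while $\frac{d}{da}\left[-\log \sigma(a)\right] = -(1 - \sigma(a))$. A small NumPy sketch of these two expressions (the function name is hypothetical; the gradient reaching G is this value times $\partial a / \partial \theta_G$ by the chain rule):

```python
import numpy as np

def generator_logit_grads(a):
    """Gradients w.r.t. the discriminator logit a on a fake sample, D(G(z)) = sigmoid(a).

    saturating loss      log(1 - D)  ->  d/da = -D
    non-saturating loss  -log D      ->  d/da = -(1 - D)
    """
    d = 1.0 / (1.0 + np.exp(-a))
    return -d, -(1.0 - d)

for a in [-6.0, -2.0, 0.0, 2.0]:
    sat, nonsat = generator_logit_grads(a)
    d = 1.0 / (1.0 + np.exp(-a))
    print(f"D(G(z)) = {d:.4f}   saturating: {sat:+.4f}   non-saturating: {nonsat:+.4f}")
```

Both gradients have the same sign, so both push $D(G(z))$ upward; when $D(G(z)) \approx 0$ the saturating one is nearly zero while the non-saturating one stays close to $-1$.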

Worked Example — Discriminator and Generator Loss

Setup: Mini-batch of 4 samples. Discriminator outputs (probability of being real):

| Sample | Type | $D(x)$ |
|---|---|---|
| $x_1$ | Real | 0.9 |
| $x_2$ | Real | 0.7 |
| $G(z_1)$ | Fake | 0.3 |
| $G(z_2)$ | Fake | 0.6 |

Discriminator loss (wants to maximize V, so we negate for minimization):

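$$L_D = -\left[\frac{1}{2}\sum_{i=1}^{2} \log D(x_i) + \frac{1}{2}\sum_{j=1}^{2} \log(1 - D(G(z_j)))\right]$$

Real part: $\frac{1}{2}[\log(0.9) + \log(0.7)] = \frac{1}{2}[-0.105 + (-0.357)] = -0.231$

Fake part: $\frac{1}{2}[\log(1 - 0.3) + \log(1 - 0.6)] = \frac{1}{2}[\log(0.7) + \log(0.4)] = \frac{1}{2}[-0.357 + (-0.916)] = -0.637$

$$L_D = -(-0.231 + (-0.637)) = 0.868$$

Generator loss (non-saturating): $L_G = -\frac{1}{2}[\log(0.3) + \log(0.6)] = -\frac{1}{2}[-1.204 + (-0.511)] = 0.857$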

Result: The discriminator is doing well on real images ($D = 0.9, 0.7$) but is fooled by $G(z_2)$ ($D = 0.6$ for a fake). The generator's per-sample loss is high for $G(z_1)$ ($D = 0.3$, easily detected) and lower for $G(z_2)$ ($D = 0.6$, partially fooling D). Training will push D to output lower scores for fakes and G to produce more convincing fakes.
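The worked example can be reproduced with plain Python. This sketch keeps four decimals, so it differs from the hand calculation only where the page rounds intermediate logs to three places (for example, the fake part is $-0.6365$ before rounding):

```python
import math

d_real = [0.9, 0.7]  # D(x_1), D(x_2)
d_fake = [0.3, 0.6]  # D(G(z_1)), D(G(z_2))

real_part = sum(math.log(d) for d in d_real) / len(d_real)
fake_part = sum(math.log(1 - d) for d in d_fake) / len(d_fake)
loss_d = -(real_part + fake_part)

# Non-saturating generator loss: -mean(log D(G(z)))
loss_g = -sum(math.log(d) for d in d_fake) / len(d_fake)

print(f"{real_part:.4f} {fake_part:.4f} {loss_d:.4f} {loss_g:.4f}")
# -0.2310 -0.6365 0.8675 0.8574
```

Note that the MNIST training loop on this page averages its two `nn.BCELoss` terms, so its `d_loss` for these outputs would be half of $L_D$ (about 0.434).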

Mode Collapse

Mode collapse is the most notorious GAN failure mode: the generator produces only a few types of output, ignoring other modes of the data distribution.

Why it happens: The generator finds a single output that consistently fools the discriminator and exploits it, rather than covering the full data distribution.

Example: A GAN trained on MNIST generates only 3s and 7s, ignoring all other digits.

Detecting Mode Collapse

python
import torch

def check_mode_collapse(generator, classifier, latent_dim, device,
                        n_samples=1000, n_classes=10):
    """Heuristically check whether a GAN generates a diverse spread of classes.

    `classifier` is any pretrained model that maps generated samples to
    class logits (e.g. an MNIST digit classifier).
    """
    z = torch.randn(n_samples, latent_dim, device=device)
    with torch.no_grad():
        fake = generator(z)
        predictions = classifier(fake).argmax(dim=1)
    class_counts = torch.bincount(predictions, minlength=n_classes)

    print("Generated class distribution:")
    for i, count in enumerate(class_counts):
        print(f"  Class {i}: {count.item()} ({100 * count.item() / n_samples:.1f}%)")

    # Heuristic: a class that never appears, or one class taking >50% of samples,
    # suggests mode collapse
    return bool((class_counts == 0).any() or (class_counts > n_samples * 0.5).any())

Solutions to Mode Collapse

| Technique | How It Helps |
|---|---|
| WGAN / WGAN-GP | More stable loss landscape |
| Minibatch discrimination | D can detect lack of diversity |
| Unrolled GANs | G anticipates D's future state |
| Spectral normalization | Controls D's Lipschitz constant |
| Feature matching | G matches statistics, not specific outputs |
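Feature matching (Salimans et al., 2016) trains G to minimize $\|\mathbb{E}_x f(x) - \mathbb{E}_z f(G(z))\|_2^2$, where $f$ is an intermediate layer of the discriminator. A NumPy sketch of just that quantity, assuming the features have already been extracted as arrays; `feature_matching_loss` is a hypothetical name, and in training the same expression would be computed on framework tensors so gradients reach G:

```python
import numpy as np

def feature_matching_loss(real_feats, fake_feats):
    """Squared L2 distance between the batch-mean features of real and fake samples.

    Both arrays have shape (batch, n_features) and come from the same hidden
    layer of the discriminator (which layer to use is a design choice).
    """
    diff = real_feats.mean(axis=0) - fake_feats.mean(axis=0)
    return float(np.sum(diff ** 2))

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(64, 3))
print(feature_matching_loss(real_feats, real_feats + 1.0))  # each feature mean is off by 1
```

Because only batch statistics are matched, G is not rewarded for collapsing onto one sample that D happens to score highly.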

WGAN: Wasserstein GAN

The Problem with JS Divergence

When $p_{\text{data}}$ and $p_g$ have non-overlapping support (common in high dimensions), the JS divergence is constant ($\log 2$), providing zero gradient. The generator cannot learn.

Earth Mover's Distance

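The Wasserstein-1 (Earth Mover's) distance measures the minimum cost to transport mass from $p_g$ to $p_{\text{data}}$:

$$W(p_{\text{data}}, p_g) = \inf_{\gamma \in \Pi(p_{\text{data}}, p_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\|x - y\|\right]$$

where $\Pi(p_{\text{data}}, p_g)$ is the set of all joint distributions (couplings) whose marginals are $p_{\text{data}}$ and $p_g$.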

By the Kantorovich-Rubinstein duality:

$$W(p_{\text{data}}, p_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)]$$

where the supremum is over 1-Lipschitz functions.
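The contrast with JS divergence is easy to see in one dimension. A NumPy sketch with two illustrative distributions on the integers: $p_{\text{data}}$ is uniform on $\{0, 1\}$ and $p_g$ is the same shape shifted so the supports never overlap. For 1D distributions, $W_1 = \int |F_{\text{data}}(t) - F_g(t)|\,dt$, which on a unit-spaced grid is a sum of CDF differences (helper names are hypothetical):

```python
import numpy as np

def kl(a, b):
    mask = a > 0
    return np.sum(a[mask] * np.log(a[mask] / b[mask]))

def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_on_integer_grid(p, q):
    """1D Wasserstein-1 for pmfs on the points 0, 1, 2, ...: sum of |CDF_p - CDF_q|."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))))

n = 20
p_data = np.zeros(n)
p_data[0:2] = 0.5                  # uniform on {0, 1}
results = {}
for shift in [2, 5, 10]:
    p_g = np.zeros(n)
    p_g[shift:shift + 2] = 0.5     # same shape, shifted: supports do not overlap
    results[shift] = (js(p_data, p_g), w1_on_integer_grid(p_data, p_g))
    print(shift, round(float(results[shift][0]), 4), results[shift][1])
# JS stays at log 2 for every shift, while W1 equals the shift
```

JS gives the generator no signal about how far off it is; $W_1$ shrinks smoothly as $p_g$ moves toward $p_{\text{data}}$.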

WGAN Objective

$$\min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]$$

where $\mathcal{D}$ is the set of 1-Lipschitz functions. The discriminator (now called the "critic") outputs an unbounded real number, not a probability.

Gradient Penalty (WGAN-GP)

The original WGAN enforced the Lipschitz constraint by clipping the critic's weights to a small range, a crude approach that can make optimization difficult. WGAN-GP (Gulrajani et al., 2017) instead adds a gradient penalty:

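$$L_{\text{WGAN-GP}} = \underbrace{\mathbb{E}_z[D(G(z))] - \mathbb{E}_x[D(x)]}_{\text{negated Wasserstein estimate}} + \lambda \underbrace{\mathbb{E}_{\hat{x}}\left[\left(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1\right)^2\right]}_{\text{gradient penalty}}$$

The critic minimizes this loss. Here $\hat{x} = \epsilon x + (1 - \epsilon) G(z)$ with $\epsilon \sim \text{Uniform}(0, 1)$ (a random interpolation between real and fake samples), and $\lambda = 10$ is the standard choice.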

python
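import torch

def gradient_penalty(discriminator, real, fake, device):
    batch_size = real.size(0)
    # One epsilon per sample, shaped to broadcast over flat (B, D) or image (B, C, H, W) inputs
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device)
    # Detach so `interpolated` is a fresh leaf tensor we can differentiate with respect to
    interpolated = (epsilon * real.detach() + (1 - epsilon) * fake.detach()).requires_grad_(True)

    d_interpolated = discriminator(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,  # keep the graph so the penalty can be backpropagated into D
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty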
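The penalty is one piece of a training step. A minimal sketch of one WGAN-GP update, assuming flat (vector) samples and a critic without a final sigmoid. `wgan_gp_penalty` is a compact, hypothetical restatement of the function above so the block runs on its own, and `wgan_gp_step` is likewise a hypothetical name. `n_critic=5` and `lam=10.0` follow the values on this page; the Adam settings in the usage (`lr=1e-4`, `betas=(0.0, 0.9)`) are the ones reported in the WGAN-GP paper, not the page's general Adam recommendation:

```python
import torch
import torch.nn as nn

def wgan_gp_penalty(critic, real, fake):
    """E[(||grad D(x_hat)||_2 - 1)^2] on random real/fake interpolates."""
    eps = torch.rand((real.size(0),) + (1,) * (real.dim() - 1), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    # Summing the outputs gives the same per-sample gradients as grad_outputs=ones
    grads = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)[0]
    return ((grads.reshape(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

def wgan_gp_step(G, critic, opt_G, opt_C, real, latent_dim, n_critic=5, lam=10.0):
    """n_critic critic updates followed by one generator update.

    For brevity the same real batch is reused for every critic update; a full
    training loop would draw a fresh real batch each time.
    """
    batch = real.size(0)
    for _ in range(n_critic):
        fake = G(torch.randn(batch, latent_dim, device=real.device)).detach()
        # Critic minimizes E[D(G(z))] - E[D(x)] + lambda * GP
        c_loss = critic(fake).mean() - critic(real).mean() + lam * wgan_gp_penalty(critic, real, fake)
        opt_C.zero_grad()
        c_loss.backward()
        opt_C.step()

    # Generator minimizes -E[D(G(z))]
    g_loss = -critic(G(torch.randn(batch, latent_dim, device=real.device))).mean()
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return c_loss.item(), g_loss.item()

# Toy usage: 4-dimensional data and small MLPs
torch.manual_seed(0)
G = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))
critic = nn.Sequential(nn.Linear(4, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_C = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
c_loss, g_loss = wgan_gp_step(G, critic, opt_G, opt_C, torch.randn(64, 4) + 2.0, latent_dim=8)
print(f"critic loss {c_loss:.4f} | generator loss {g_loss:.4f}")
```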

Conditional GAN (cGAN)

Condition the generator and discriminator on additional information (class label, text, etc.):

$$\min_G \max_D \mathbb{E}_{x, y}[\log D(x, y)] + \mathbb{E}_{z, y}[\log(1 - D(G(z, y), y))]$$
python
import numpy as np
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, n_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)

        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )
        self.img_shape = img_shape

    def forward(self, z, labels):
        label_embedding = self.label_emb(labels)
        gen_input = torch.cat([z, label_embedding], dim=1)
        img = self.model(gen_input)
        return img.view(img.size(0), *self.img_shape)
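The objective also conditions the discriminator, $D(x, y)$. A hypothetical sketch of a matching discriminator that mirrors the generator's design by concatenating a label embedding with the flattened image; the layer widths are arbitrary choices, and concatenation is only one option (projection discriminators are a common alternative):

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    """MLP discriminator D(x, y): flattened image concatenated with a label embedding."""

    def __init__(self, n_classes, img_dim):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.model = nn.Sequential(
            nn.Linear(img_dim + n_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that (x, y) is a real, correctly labelled pair
        )

    def forward(self, img, labels):
        x = img.reshape(img.size(0), -1)
        d_input = torch.cat([x, self.label_emb(labels)], dim=1)
        return self.model(d_input)

D = ConditionalDiscriminator(n_classes=10, img_dim=28 * 28)
scores = D(torch.randn(4, 1, 28, 28), torch.tensor([0, 1, 2, 3]))
print(scores.shape)  # torch.Size([4, 1])
```

During training, fake pairs are scored as `D(G(z, y), y)` with the same labels `y` that were fed to the generator.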

From-Scratch GAN: MNIST

python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# ── Hyperparameters ──────────────────────────────────────────────────
LATENT_DIM = 100
IMG_DIM = 784  # 28 x 28
BATCH_SIZE = 128
EPOCHS = 100
LR = 2e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ── Data ─────────────────────────────────────────────────────────────
transform = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, BATCH_SIZE, shuffle=True, drop_last=True)

# ── Generator ────────────────────────────────────────────────────────
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, IMG_DIM),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

# ── Discriminator ────────────────────────────────────────────────────
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(IMG_DIM, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

# ── Training ─────────────────────────────────────────────────────────
G = Generator().to(DEVICE)
D = Discriminator().to(DEVICE)

opt_G = torch.optim.Adam(G.parameters(), lr=LR, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCELoss()

for epoch in range(EPOCHS):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, IMG_DIM).to(DEVICE)
        batch_size = real_imgs.size(0)

        real_labels = torch.ones(batch_size, 1, device=DEVICE)
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

        # ── Train Discriminator ──────────────────────────────────
        z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
        fake_imgs = G(z).detach()

        d_loss_real = criterion(D(real_imgs), real_labels)
        d_loss_fake = criterion(D(fake_imgs), fake_labels)
        d_loss = (d_loss_real + d_loss_fake) / 2

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # ── Train Generator ──────────────────────────────────────
        z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
        fake_imgs = G(z)
        g_loss = criterion(D(fake_imgs), real_labels)  # Non-saturating loss

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f}")

# ── Generate samples ─────────────────────────────────────────────────
G.eval()
with torch.no_grad():
    z = torch.randn(64, LATENT_DIM, device=DEVICE)
    samples = G(z).view(-1, 1, 28, 28).cpu()
    torchvision.utils.save_image(samples, 'gan_samples.png', nrow=8, normalize=True)

Training Tips

Architecture Guidelines

| Component | Recommendation |
|---|---|
| G activation (hidden) | LeakyReLU(0.2) or ReLU |
| G activation (output) | Tanh (images in [-1, 1]) |
| D activation (hidden) | LeakyReLU(0.2) |
| D activation (output) | Sigmoid (vanilla) or none (WGAN) |
| Normalization (G) | BatchNorm (not in output layer) |
| Normalization (D) | LayerNorm or SpectralNorm (not BatchNorm with GP) |
| Optimizer | Adam with $\beta_1 = 0.5$, $\beta_2 = 0.999$ |
| Learning rate | $10^{-4}$ to $2 \times 10^{-4}$ |

Stability Tricks

  1. Label smoothing: Use 0.9 instead of 1.0 for real labels
  2. Noisy labels: Occasionally flip labels (5% of the time)
  3. Train D more than G: Especially with WGAN (5 D steps per 1 G step)
  4. Spectral normalization: Stabilizes D without gradient penalty
  5. Two time-scale update rule (TTUR): Higher LR for D than G
  6. Progressive growing: Start at low resolution, gradually increase
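Tricks 1 and 2 only change the discriminator's BCE targets. A minimal sketch, assuming the BCE setup of the MNIST loop above; `make_targets` is a hypothetical helper, and "flipping" here swaps a real target (0.9) with a fake one (0.0). The generator step would keep its targets at 1.0:

```python
import torch

def make_targets(batch_size, real, smooth=0.9, flip_prob=0.05, device="cpu"):
    """BCE targets for the discriminator with one-sided label smoothing and random flips.

    Real targets are `smooth` instead of 1.0; fake targets stay 0.0. A random
    ~flip_prob fraction of targets is then swapped to the opposite class value.
    """
    real_value, fake_value = smooth, 0.0
    value = real_value if real else fake_value
    flipped = fake_value if real else real_value
    targets = torch.full((batch_size, 1), value, device=device)
    flip_mask = torch.rand(batch_size, 1, device=device) < flip_prob
    return torch.where(flip_mask, torch.full_like(targets, flipped), targets)

torch.manual_seed(0)
real_targets = make_targets(1000, real=True)
print((real_targets < 0.5).float().mean().item())  # roughly 0.05 of the targets were flipped
```

In the MNIST loop this would replace `real_labels` and `fake_labels` in the discriminator step only, e.g. `criterion(D(real_imgs), make_targets(batch_size, True, device=DEVICE))`.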

Evaluation Metrics

FID (Fréchet Inception Distance):

$$\text{FID} = \|\mu_r - \mu_g\|_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$$

Lower FID = better quality and diversity. Compute using features from a pretrained Inception network.

Worked Example — FID Score (Simplified 2D)

Setup: 2D feature space (real FID uses 2048-dim Inception features).

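Real images: $\mu_r = [3.0, 5.0]$, $\Sigma_r = \begin{bmatrix} 1.0 & 0.2 \\ 0.2 & 1.5 \end{bmatrix}$

Generated images: $\mu_g = [3.5, 4.5]$, $\Sigma_g = \begin{bmatrix} 1.2 & 0.1 \\ 0.1 & 2.0 \end{bmatrix}$

Step 1: Mean distance: $\|\mu_r - \mu_g\|^2 = (3 - 3.5)^2 + (5 - 4.5)^2 = 0.25 + 0.25 = 0.50$

Step 2: Trace terms: $\operatorname{Tr}(\Sigma_r) = 1.0 + 1.5 = 2.5$, $\operatorname{Tr}(\Sigma_g) = 1.2 + 2.0 = 3.2$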

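Step 3: Matrix square root term: $\Sigma_r \Sigma_g = \begin{bmatrix} 1.22 & 0.50 \\ 0.39 & 3.02 \end{bmatrix}$, with trace $4.24$ and determinant $3.4894$. For a $2 \times 2$ matrix $M$ with nonnegative eigenvalues $\lambda_1, \lambda_2$, $\operatorname{Tr}(M^{1/2}) = \sqrt{\lambda_1} + \sqrt{\lambda_2} = \sqrt{\operatorname{Tr}(M) + 2\sqrt{\det M}}$, so $\operatorname{Tr}((\Sigma_r \Sigma_g)^{1/2}) = \sqrt{4.24 + 2 \times 1.868} \approx 2.824$ (in higher dimensions this requires an eigendecomposition).

Step 4: FID:

$$\text{FID} = 0.50 + 2.5 + 3.2 - 2 \times 2.824 = 0.50 + 5.70 - 5.648 \approx 0.55$$

Result: FID ≈ 0.55. A perfect generator ($\mu_g = \mu_r$, $\Sigma_g = \Sigma_r$) would give FID = 0. Typical good GANs achieve FID of 5-30 on real datasets; state-of-the-art is below 2 on CIFAR-10.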
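The same computation in NumPy, starting from feature means and covariances. This is a sketch: `fid_from_stats` and `psd_sqrt` are hypothetical helpers. Common implementations take a general matrix square root of $\Sigma_r \Sigma_g$ (for example with `scipy.linalg.sqrtm`); this version uses the fact that $\Sigma_r \Sigma_g$ has the same eigenvalues as the symmetric matrix $\Sigma_r^{1/2} \Sigma_g \Sigma_r^{1/2}$, so it needs only NumPy. Applied to the 2D example's statistics it gives $\operatorname{Tr}((\Sigma_r \Sigma_g)^{1/2}) \approx 2.824$ and FID ≈ 0.55:

```python
import numpy as np

def psd_sqrt(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0, None))) @ v.T

def fid_from_stats(mu_r, sigma_r, mu_g, sigma_g):
    """FID from the means and covariances of real and generated features."""
    mean_term = np.sum((mu_r - mu_g) ** 2)
    # Tr((Sr Sg)^{1/2}) == Tr((Sr^{1/2} Sg Sr^{1/2})^{1/2}); the latter is symmetric PSD
    s = psd_sqrt(sigma_r)
    eigs = np.linalg.eigvalsh(s @ sigma_g @ s)
    cross = np.sum(np.sqrt(np.clip(eigs, 0, None)))
    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_g) - 2 * cross)

mu_r, sigma_r = np.array([3.0, 5.0]), np.array([[1.0, 0.2], [0.2, 1.5]])
mu_g, sigma_g = np.array([3.5, 4.5]), np.array([[1.2, 0.1], [0.1, 2.0]])
fid = fid_from_stats(mu_r, sigma_r, mu_g, sigma_g)
print(f"{fid:.4f}")  # 0.5516
```

With real Inception features, `mu` and `sigma` would be `feats.mean(axis=0)` and `np.cov(feats, rowvar=False)` over thousands of 2048-dimensional vectors.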

IS (Inception Score):

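$$\text{IS} = \exp\left(\mathbb{E}_x\left[D_{KL}(p(y \mid x) \,\|\, p(y))\right]\right)$$

Higher IS = sharper and more diverse images. IS is generally considered less reliable than FID, in part because it never compares generated samples against real data.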
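A NumPy sketch of the formula, assuming `probs` already holds classifier softmax outputs $p(y \mid x)$ for $N$ generated samples (Inception-v3 in the standard metric) and estimating $p(y)$ as their mean. `inception_score` is a hypothetical helper; commonly reported IS values also average over several splits of the samples, which is omitted here:

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp(mean over x of KL(p(y|x) || p(y))), from an (N, n_classes) probability array."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl_per_sample = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl_per_sample.mean()))

confident_and_diverse = np.eye(10)                     # each sample a different, certain class
collapsed = np.tile(np.eye(10)[0], (10, 1))            # every sample confidently class 0
print(f"{inception_score(confident_and_diverse):.4f}")  # 10.0000 (the maximum for 10 classes)
print(f"{inception_score(collapsed):.4f}")              # 1.0000 (no diversity)
```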

GAN Variants Timeline

| Year | Variant | Key Innovation |
|---|---|---|
| 2014 | GAN | Original minimax formulation |
| 2014 | cGAN | Conditional generation |
| 2016 | DCGAN | Convolutional architecture guidelines |
| 2017 | WGAN | Wasserstein distance |
| 2017 | WGAN-GP | Gradient penalty |
| 2018 | StyleGAN | Style-based generator |
| 2019 | BigGAN | Large-scale, class-conditional |
| 2020 | StyleGAN2 | Improved normalization |
| 2021 | StyleGAN3 | Alias-free generation |

DCGAN Architecture Guidelines

DCGAN (Radford et al., 2016) established rules for stable convolutional GANs:

  1. Replace all pooling with strided convolutions (discriminator) and transposed convolutions (generator)
  2. Use BatchNorm in both G and D (except G output and D input)
  3. Remove all fully connected layers (except G input and D output)
  4. Use ReLU in G (except output: Tanh) and LeakyReLU in D
python
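import torch.nn as nn

# Layer sizes assume 64x64 images; MNIST's 28x28 digits would need resizing (e.g. torchvision.transforms.Resize(64))
class DCGANGenerator(nn.Module):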
    def __init__(self, latent_dim=100, channels=1, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: (latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, features * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.ReLU(True),
            # State: (features*8, 4, 4)
            nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # State: (features*4, 8, 8)
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # State: (features*2, 16, 16)
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # State: (features, 32, 32)
            nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Output: (channels, 64, 64)
        )

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))

class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=1, features=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 4, features * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)
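The DCGAN paper also initializes all weights from a zero-centered normal distribution with standard deviation 0.02. A sketch of that as a function for `Module.apply`; `dcgan_weights_init` is a hypothetical name, and setting BatchNorm scales to $\mathcal{N}(1, 0.02)$ with zero biases is a common convention (used, for example, in PyTorch's DCGAN tutorial) rather than something stated in the guidelines above:

```python
import torch
import torch.nn as nn

def dcgan_weights_init(m):
    """Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02), bias = 0."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

torch.manual_seed(0)
demo = nn.Sequential(nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False), nn.BatchNorm2d(64))
demo.apply(dcgan_weights_init)  # apply() visits every submodule recursively
print(round(demo[0].weight.std().item(), 3))
```

The same call works on the full models, e.g. `DCGANGenerator().apply(dcgan_weights_init)`.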

Spectral Normalization

Spectral normalization (Miyato et al., 2018) constrains the Lipschitz constant of the discriminator by normalizing each weight matrix by its spectral norm (largest singular value):

$$\bar{W} = \frac{W}{\sigma(W)}$$

where $\sigma(W) = \max_{\|h\|_2 = 1} \|W h\|_2$ is the spectral norm.

python
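# PyTorch built-in
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Assumes 3x32x32 inputs: spatial size 32 -> 16 -> 8, so 256 * 8 * 8 features reach the linear layer
discriminator = nn.Sequential(
    spectral_norm(nn.Conv2d(3, 64, 3, padding=1)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1)),
    nn.LeakyReLU(0.2),
    nn.Flatten(),  # (N, 256, 8, 8) -> (N, 16384) before the linear layer
    spectral_norm(nn.Linear(256 * 8 * 8, 1)),
)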
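Computing $\sigma(W)$ exactly with an SVD at every step would be expensive, so it is estimated by power iteration; PyTorch's `spectral_norm` runs `n_power_iterations=1` per forward pass by default and reuses its vector between steps (for conv layers the weight is reshaped to 2D first). A standalone NumPy sketch of the estimate, with a hypothetical function name, run to convergence for illustration:

```python
import numpy as np

def spectral_norm_power_iteration(w, n_iters=50, seed=0):
    """Estimate the largest singular value of a 2D matrix w by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = w.T @ u            # right singular vector estimate
        v /= np.linalg.norm(v)
        u = w @ v              # left singular vector estimate
        u /= np.linalg.norm(u)
    return float(u @ w @ v)    # sigma ~ u^T W v

rng = np.random.default_rng(1)
w = rng.normal(size=(20, 10))
print(spectral_norm_power_iteration(w, n_iters=500), np.linalg.norm(w, 2))
```

Dividing `w` by this estimate gives the normalized $\bar{W}$, whose spectral norm is approximately 1.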

GAN Applications Beyond Image Generation

| Application | Approach | Description |
|---|---|---|
| Image-to-image translation | Pix2Pix, CycleGAN | Convert between domains (sketch to photo) |
| Super-resolution | SRGAN, ESRGAN | Upscale low-resolution images |
| Data augmentation | Progressive GAN | Generate synthetic training data |
| Anomaly detection | AnoGAN | Model the distribution of normal data; flag samples it fits poorly |
| Drug discovery | MolGAN | Generate molecular graphs |
| Video prediction | DVD-GAN | Generate future frames |

Debugging GAN Training

| Symptom | Likely Cause | Fix |
|---|---|---|
| D loss drops to 0 | D too strong | Reduce D capacity, add noise to D inputs |
| G loss stuck high | G too weak or D too strong | Increase G capacity, reduce D training steps |
| D loss oscillates wildly | Unstable training | Use WGAN-GP or spectral normalization |
| Mode collapse (all same output) | G found an exploit | Use minibatch discrimination, unrolled GAN, or WGAN-GP |
| Checkerboard artifacts | Transposed convolution | Use resize + conv instead of transposed conv |
| Both losses hover near 0.69 | Near the Nash equilibrium ($D \approx 0.5$) | This can be normal ($\log 2 \approx 0.693$) |


"What I cannot create, I do not understand." — Richard Feynman