Unverified: AI-generated content.

Image Segmentation

Image segmentation assigns a class label to every pixel in an image. It is critical for autonomous driving (road vs sidewalk), medical imaging (tumor vs healthy tissue), and satellite imagery (land use). This page covers the three types of segmentation, builds U-Net from scratch, explains DeepLab's atrous convolutions, walks through Mask R-CNN, introduces SAM, and trains a medical image segmenter.

Types of Segmentation

| Type | Output | Example |
|---|---|---|
| Semantic | Per-pixel class label (no instance distinction) | All cars = same color |
| Instance | Per-pixel class + instance ID | Each car = different color |
| Panoptic | Semantic + instance for countable objects | Cars are instances, sky is semantic |
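The three output types can be compared on a toy scene. In this NumPy sketch, the class IDs and the `class * 1000 + instance` panoptic encoding are illustrative assumptions, not a fixed standard. The semantic map cannot tell the two cars apart, while the panoptic map gives them distinct IDs and still labels every sky and road pixel:

```python
import numpy as np

SKY, ROAD, CAR = 0, 1, 2   # hypothetical class IDs; CAR is the only countable ("thing") class
THINGS = [CAR]

# Semantic output: one class label per pixel (both cars share the label CAR)
semantic = np.array([
    [SKY,  SKY,  SKY,  SKY,  SKY,  SKY],
    [SKY,  CAR,  CAR,  SKY,  CAR,  CAR],
    [ROAD, CAR,  CAR,  ROAD, CAR,  CAR],
    [ROAD, ROAD, ROAD, ROAD, ROAD, ROAD],
])

# Instance output: an ID per countable object (0 = no instance)
instance = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 2, 2],
    [0, 1, 1, 0, 2, 2],
    [0, 0, 0, 0, 0, 0],
])

# Panoptic output: every pixel gets a class; "thing" pixels also keep their instance ID
is_thing = np.isin(semantic, THINGS)
panoptic = semantic * 1000 + np.where(is_thing, instance, 0)

print(len(np.unique(semantic)), len(np.unique(panoptic)))  # 3 4
```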

U-Net

Architecture

U-Net (Ronneberger et al., 2015) has an encoder-decoder structure with skip connections that concatenate encoder features directly onto the decoder features at the matching resolution level.
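To make the encoder-decoder structure concrete, the sketch below traces channel count and spatial size through each stage for a 256x256 input. It assumes the padded variant implemented later on this page (3x3 convolutions with `padding=1`, features 64 to 512); the original paper used unpadded convolutions and cropped the skip features, so its sizes differ. The function name is illustrative.

```python
def unet_shape_trace(size=256, features=(64, 128, 256, 512)):
    """List (stage, channels, side) for a padded U-Net on a size x size input."""
    trace, skips = [], []
    s = size
    for f in features:
        # DoubleConv with padding=1 changes channels but keeps the spatial size
        trace.append(("encoder", f, s))
        skips.append((f, s))
        s //= 2                                   # 2x2 max-pool halves H and W
    trace.append(("bottleneck", features[-1] * 2, s))
    for f in reversed(features):
        s *= 2                                    # 2x2 transposed conv (stride 2) doubles H and W
        skip_channels, skip_side = skips.pop()
        assert skip_side == s                     # holds when size is divisible by 2**len(features)
        trace.append(("concat", f + skip_channels, s))  # upsampled f channels + skip channels
        trace.append(("decoder", f, s))           # DoubleConv maps 2f -> f channels
    return trace

for stage, channels, side in unet_shape_trace():
    print(f"{stage:<10} {channels:>5} x {side} x {side}")
```

A final 1x1 convolution then maps the last 64 channels to `num_classes` at full resolution.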

Why Skip Connections?

The encoder loses spatial detail through downsampling. Skip connections provide the decoder with fine-grained spatial information from the encoder, combined with the semantic information from the bottleneck.

U-Net From Scratch

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two consecutive Conv-BN-ReLU blocks."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2, features=(64, 128, 256, 512)):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Encoder
        for f in features:
            self.encoder.append(DoubleConv(in_channels, f))
            in_channels = f

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder
        for f in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
            )
            self.decoder.append(DoubleConv(f * 2, f))

        # Final 1x1 convolution
        self.final_conv = nn.Conv2d(features[0], num_classes, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder path
        for enc in self.encoder:
            x = enc(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # Decoder path
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)  # Upsample
            skip = skip_connections[i // 2]

            # Handle size mismatch (if input is not divisible by 2^n)
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            x = torch.cat([skip, x], dim=1)  # Concatenate skip connection
            x = self.decoder[i + 1](x)       # Double conv

        return self.final_conv(x)

# Test
model = UNet(in_channels=1, num_classes=2)
x = torch.randn(1, 1, 256, 256)
out = model(x)
print(f"Input: {x.shape}, Output: {out.shape}")
# Input: torch.Size([1, 1, 256, 256]), Output: torch.Size([1, 2, 256, 256])

Loss Functions for Segmentation

Dice Loss

The Dice coefficient measures overlap between prediction and ground truth:

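$$\text{Dice} = \frac{2|P \cap G|}{|P| + |G|} = \frac{2\sum_i p_i g_i}{\sum_i p_i + \sum_i g_i}, \qquad \mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}$$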
Worked Example — Dice Coefficient Calculation

Setup: Binary segmentation (tumor vs background). 4x4 pixel region.

Prediction P (after sigmoid, probabilities):

$$P = \begin{pmatrix} 0.9 & 0.8 & 0.1 & 0.0 \\ 0.7 & 0.6 & 0.2 & 0.1 \\ 0.1 & 0.3 & 0.1 & 0.0 \\ 0.0 & 0.1 & 0.0 & 0.0 \end{pmatrix}$$

Ground truth G (binary):

$$G = \begin{pmatrix} 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{pmatrix}$$

Step 1: Compute the soft intersection $\sum_i p_i g_i$:

$$\sum_i p_i g_i = 0.9(1) + 0.8(1) + 0.1(0) + 0.0(0) + 0.7(1) + 0.6(1) + 0.2(1) + 0.1(0) + \dots = 0.9 + 0.8 + 0.7 + 0.6 + 0.2 = 3.2$$

Step 2: Compute $\sum_i p_i$ and $\sum_i g_i$:

$$\sum_i p_i = 0.9 + 0.8 + 0.1 + 0 + 0.7 + 0.6 + 0.2 + 0.1 + 0.1 + 0.3 + 0.1 + 0 + 0 + 0.1 + 0 + 0 = 4.0, \qquad \sum_i g_i = 5 \ \text{(five 1s in the ground truth)}$$

Step 3: Dice coefficient:

$$\text{Dice} = \frac{2 \times 3.2}{4.0 + 5.0} = \frac{6.4}{9.0} \approx 0.711$$

Step 4: Dice loss:

$$\mathcal{L}_{\text{Dice}} = 1 - 0.711 = 0.289$$

Result: A soft Dice of 0.711 indicates substantial but imperfect overlap between prediction and ground truth. Using 0-indexed (row, column) coordinates, the model captures most of the tumor region (the top-left 2x2 block) but under-predicts the pixel at (1,2) (predicted 0.2 vs. ground truth 1) and produces a false positive at (2,1) (predicted 0.3 vs. ground truth 0). Dice loss is often preferred over plain BCE when classes are heavily imbalanced: true-negative background pixels (prediction 0, label 0) contribute nothing to the numerator or the denominator, so even if 95% of pixels are background, the score is driven by how well the foreground is captured.
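The worked example can be checked numerically. This NumPy snippet recomputes each step using the same 4x4 prediction `P` and ground truth `G`:

```python
import numpy as np

P = np.array([[0.9, 0.8, 0.1, 0.0],
              [0.7, 0.6, 0.2, 0.1],
              [0.1, 0.3, 0.1, 0.0],
              [0.0, 0.1, 0.0, 0.0]])
G = np.array([[1, 1, 0, 0],
              [1, 1, 1, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0]])

intersection = (P * G).sum()                    # Step 1: soft intersection
dice = 2 * intersection / (P.sum() + G.sum())   # Steps 2-3: soft Dice coefficient
loss = 1 - dice                                 # Step 4: Dice loss

print(f"intersection={intersection:.1f}, sum_p={P.sum():.1f}, "
      f"sum_g={G.sum()}, dice={dice:.3f}, loss={loss:.3f}")
# intersection=3.2, sum_p=4.0, sum_g=5, dice=0.711, loss=0.289
```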

python
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
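        pred = torch.sigmoid(pred)  # logits -> probabilities
        # reshape (unlike view) also works on non-contiguous tensors; flattening the
        # whole batch computes a single Dice score over all pixels in the batch
        pred = pred.reshape(-1)
        target = target.reshape(-1)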

        intersection = (pred * target).sum()
        dice = (2 * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth
        )
        return 1 - dice

Combined Loss

$$\mathcal{L} = \alpha \mathcal{L}_{\text{BCE}} + (1 - \alpha) \mathcal{L}_{\text{Dice}}$$

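Combining BCE (a per-pixel term with stable gradients) with Dice (a region-level overlap term that counters class imbalance) is a common default that often outperforms either loss alone. The training loop later on this page simply adds the two, which is equivalent to $\alpha = 0.5$ up to a constant factor of 2.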
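A minimal PyTorch sketch of the weighted combination for binary segmentation logits. The class name `BCEDiceLoss` and the default `alpha=0.5` are illustrative choices, not a library API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BCEDiceLoss(nn.Module):
    """alpha * BCE + (1 - alpha) * soft Dice loss, computed from raw logits."""
    def __init__(self, alpha=0.5, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth

    def forward(self, logits, target):
        # Per-pixel term: numerically stable BCE on logits, averaged over all pixels
        bce = F.binary_cross_entropy_with_logits(logits, target)
        # Region-level term: soft Dice over the whole batch
        probs = torch.sigmoid(logits).reshape(-1)
        t = target.reshape(-1)
        intersection = (probs * t).sum()
        dice = (2 * intersection + self.smooth) / (probs.sum() + t.sum() + self.smooth)
        return self.alpha * bce + (1 - self.alpha) * (1 - dice)

# Usage: a confident correct prediction gives a near-zero loss
target = torch.zeros(1, 1, 4, 4)
target[..., :2, :2] = 1
good_logits = (target * 2 - 1) * 10
print(BCEDiceLoss()(good_logits, target).item())
```

Setting `alpha=1.0` reduces it to plain BCE; `alpha=0.0` gives pure Dice.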

Focal Loss

For class-imbalanced segmentation (e.g., small tumors in large images):

$$\mathcal{L}_{\text{focal}} = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

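where $p_t$ is the predicted probability of the true class at a pixel ($p_t = p$ for a foreground pixel, $1 - p$ for a background pixel), $\alpha_t$ is a class-balancing weight, and the focusing parameter $\gamma$ (commonly $\gamma = 2$) shrinks the loss of easy, confidently correct pixels by the factor $(1 - p_t)^{\gamma}$, so training concentrates on hard ones.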
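A NumPy sketch of the binary focal loss on per-pixel probabilities. The defaults `alpha=0.25` and `gamma=2.0` are the values the original focal-loss paper (Lin et al., 2017) found to work well for detection; treat them as starting points for segmentation rather than tuned values. The function name is illustrative.

```python
import numpy as np

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean binary focal loss. p: predicted foreground probabilities, y: {0, 1} labels."""
    p = np.clip(p, eps, 1 - eps)                   # avoid log(0)
    p_t = np.where(y == 1, p, 1 - p)               # probability assigned to the true class
    alpha_t = np.where(y == 1, alpha, 1 - alpha)   # class-balancing weight
    return np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t))

# How much gamma=2 shrinks the loss relative to gamma=0 (alpha-weighted cross-entropy)
for p in (0.9, 0.1):  # easy vs. hard foreground pixel
    ratio = (binary_focal_loss(np.array([p]), np.array([1]), gamma=2.0)
             / binary_focal_loss(np.array([p]), np.array([1]), gamma=0.0))
    print(f"p={p}: loss scaled by {ratio:.2f}")
```

The easy pixel's loss is scaled by (1 - 0.9)^2 = 0.01, the hard pixel's only by (1 - 0.1)^2 = 0.81.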

DeepLab: Atrous (Dilated) Convolution

The Problem

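Repeated pooling and strided convolutions enlarge the receptive field but shrink the feature map, and upsampling the coarse result afterwards cannot fully recover the lost spatial detail. Atrous convolution enlarges the receptive field without reducing resolution and without adding parameters.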

Atrous Convolution

A standard 3x3 kernel samples 9 adjacent pixels. An atrous (dilated) convolution with rate r samples 9 pixels spaced r apart:

$$(F *_r k)(\mathbf{p}) = \sum_{\mathbf{s} + r\mathbf{t} = \mathbf{p}} F(\mathbf{s})\, k(\mathbf{t})$$

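A single 3x3 kernel with dilation rate $r$ spans $(2r+1) \times (2r+1)$ pixels while still using only 9 weights; in general, a $k \times k$ kernel spans $r(k-1)+1$ pixels per side, and $r = 1$ recovers the standard convolution.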
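A naive NumPy sketch that makes the spacing visible. Like deep-learning frameworks, it computes cross-correlation (no kernel flip), which differs from the formula above only by flipping the kernel. Zero padding of `r * (k // 2)` keeps the output the same size as the input, which is the PyTorch equivalent of `padding=rate, dilation=rate` for a 3x3 kernel. The function name is illustrative.

```python
import numpy as np

def atrous_conv2d(x, k, r):
    """Naive 2-D atrous (dilated) cross-correlation with zero 'same' padding."""
    kh, kw = k.shape
    ph, pw = r * (kh // 2), r * (kw // 2)          # padding that preserves H and W
    xp = np.pad(x.astype(float), ((ph, ph), (pw, pw)))
    out = np.zeros(x.shape, dtype=float)
    for a in range(kh):
        for b in range(kw):
            # kernel tap (a, b) reads the input shifted by (a*r, b*r): taps are r pixels apart
            out += k[a, b] * xp[a * r : a * r + x.shape[0], b * r : b * r + x.shape[1]]
    return out

x = np.zeros((15, 15))
x[7, 7] = 1.0                                      # single impulse in the center
y = atrous_conv2d(x, np.ones((3, 3)), r=3)
rows, cols = np.nonzero(y)
print(y.shape, np.count_nonzero(y), rows.max() - rows.min() + 1)  # (15, 15) 9 7
```

The impulse response touches only 9 pixels, yet they span a 7x7 window (2r + 1 = 7 for r = 3), and the output keeps the input's 15x15 resolution.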

Atrous Spatial Pyramid Pooling (ASPP)

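DeepLabv3 and DeepLabv3+ run several atrous convolutions in parallel at different rates, alongside a 1x1 convolution branch and an image-level (global average pooling) branch, then concatenate and project the results so the module captures context at multiple scales: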

python
class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels=256, rates=[6, 12, 18]):
        super().__init__()
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        for rate in rates:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3,
                         padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ))
        # Global average pooling branch
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ))
        self.convs = nn.ModuleList(modules)
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        outputs = []
        for conv in self.convs[:-1]:
            outputs.append(conv(x))
        # Global pooling branch
        gap = self.convs[-1](x)
        gap = F.interpolate(gap, size=x.shape[2:], mode='bilinear', align_corners=False)
        outputs.append(gap)
        return self.project(torch.cat(outputs, dim=1))

Mask R-CNN

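Mask R-CNN (He et al., 2017) extends Faster R-CNN with a mask-prediction branch that runs in parallel with the existing classification and bounding-box regression branches.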

RoIAlign

RoI Pooling uses quantization (rounding to grid cells), losing spatial precision. RoIAlign uses bilinear interpolation to sample features at exact fractional locations:

$$f(x, y) = \sum_{i,j} f_{ij} \max(0, 1 - |x - i|) \max(0, 1 - |y - j|)$$

This eliminates the quantization error and improves mask quality.
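A NumPy sketch of the idea. `bilinear_sample` implements the formula above directly (with `i` indexing columns/x and `j` indexing rows/y), and `roi_align` averages a few bilinearly sampled points per output bin without rounding the box coordinates. It omits details of production implementations such as `torchvision.ops.roi_align` (for example the half-pixel offset option and adaptive sampling), and assumes box coordinates are already in feature-map units. Both function names are illustrative.

```python
import numpy as np

def bilinear_sample(feat, x, y):
    """f(x, y) = sum_ij f_ij * max(0, 1-|x-i|) * max(0, 1-|y-j|), with feat[j, i] = f_ij."""
    H, W = feat.shape
    total = 0.0
    for j in range(H):          # j: row index (y direction)
        for i in range(W):      # i: column index (x direction)
            w = max(0.0, 1 - abs(x - i)) * max(0.0, 1 - abs(y - j))
            total += feat[j, i] * w
    return total

def roi_align(feat, box, out_size=2, sampling_ratio=2):
    """Pool box = (x1, y1, x2, y2) into out_size x out_size bins, with no rounding."""
    x1, y1, x2, y2 = box
    bin_w, bin_h = (x2 - x1) / out_size, (y2 - y1) / out_size
    out = np.zeros((out_size, out_size))
    for by in range(out_size):
        for bx in range(out_size):
            vals = []
            for sy in range(sampling_ratio):
                for sx in range(sampling_ratio):
                    # regularly spaced sample points at fractional locations inside the bin
                    x = x1 + (bx + (sx + 0.5) / sampling_ratio) * bin_w
                    y = y1 + (by + (sy + 0.5) / sampling_ratio) * bin_h
                    vals.append(bilinear_sample(feat, x, y))
            out[by, bx] = np.mean(vals)
    return out

feat = np.tile(np.arange(8, dtype=float), (8, 1))   # toy feature map whose value equals x
print(roi_align(feat, (1.2, 1.2, 5.2, 5.2)))
```

On this linear ramp the pooled values (2.2 and 4.2) are exactly the mean x-coordinate of each bin's sample points, because bilinear interpolation reproduces linear functions without error, whereas rounding the box to integers would shift them.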

Architecture

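For each RoI, the mask head outputs a $K \times m \times m$ tensor: one binary $m \times m$ mask per class ($K$ classes, a small fixed resolution such as $28 \times 28$), each with a per-pixel sigmoid. During training only the mask for the RoI's ground-truth class contributes to the loss, so masks do not compete across classes; class prediction is left to the box head.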

Segment Anything Model (SAM)

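SAM (Kirillov et al., 2023) is a promptable foundation model for segmentation. Trained on the SA-1B dataset (11M images, over 1B masks), it segments the object indicated by a prompt: points, a box, or a rough mask. Text prompts were explored in the paper as a proof of concept but are not supported by the released `segment_anything` package.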

Architecture Components

  1. Image encoder: ViT-H (heavyweight, runs once per image)
  2. Prompt encoder: Encodes sparse prompts (points, boxes) and dense prompts (masks)
  3. Mask decoder: Lightweight transformer decoder that produces masks
python
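import numpy as np
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# Path to a downloaded ViT-H checkpoint
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
predictor = SamPredictor(sam)

# set_image expects an HxWx3 uint8 RGB array ("example.jpg" is a placeholder path)
image = np.array(Image.open("example.jpg").convert("RGB"))

# Set image (runs image encoder once)
predictor.set_image(image)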

# Prompt with a point
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),  # 1 = foreground
    multimask_output=True,
)
# Returns 3 candidate masks for an ambiguous prompt (e.g., subpart, part, whole object),
# plus one predicted quality (IoU) score per mask in `scores`

Medical Imaging: U-Net for Lung Segmentation

python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os

class MedicalDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = np.array(Image.open(
            os.path.join(self.image_dir, self.images[idx])
        ).convert('L'), dtype=np.float32) / 255.0
        mask = np.array(Image.open(
            os.path.join(self.mask_dir, self.masks[idx])
        ).convert('L'), dtype=np.float32) / 255.0

        img = torch.from_numpy(img).unsqueeze(0)
        mask = torch.from_numpy(mask).unsqueeze(0)

        if self.transform:
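            # Apply the same random transform to both by reseeding torch's RNG.
            # This works for torchvision transforms that draw from torch's global RNG;
            # use spatial transforms only here, since the mask must not be intensity-augmented.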
            seed = np.random.randint(2147483647)
            torch.manual_seed(seed)
            img = self.transform(img)
            torch.manual_seed(seed)
            mask = self.transform(mask)

        return img, mask

# ── Training ─────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1, num_classes=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()

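# train_loader and val_loader are assumed to be defined, e.g. (hypothetical directories):
# train_loader = DataLoader(MedicalDataset("train/images", "train/masks"), batch_size=8, shuffle=True)
# val_loader = DataLoader(MedicalDataset("val/images", "val/masks"), batch_size=8)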

for epoch in range(50):
    model.train()
    total_loss = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks) + dice_loss(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Evaluate with Dice score
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            preds = torch.sigmoid(model(images)) > 0.5
            intersection = (preds * masks).sum()
            dice = (2 * intersection) / (preds.sum() + masks.sum() + 1e-8)
            dice_scores.append(dice.item())

    print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
          f"Dice={np.mean(dice_scores):.4f}")
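One subtlety in the evaluation loop above: when a batch has no foreground in either the prediction or the ground truth, `2 * 0 / (0 + 0 + 1e-8)` evaluates to 0, penalizing a correct empty prediction. A common convention is to compute Dice per image and score the both-empty case as 1. A NumPy sketch of that convention (the function name and `empty_score` parameter are illustrative); with PyTorch tensors, pass `preds.cpu().numpy()` and `masks.cpu().numpy()`:

```python
import numpy as np

def per_image_dice(preds, targets, empty_score=1.0):
    """Dice per image for binary masks shaped (N, H, W) or (N, 1, H, W).

    Images where both prediction and target are empty receive `empty_score`.
    """
    preds = preds.reshape(len(preds), -1).astype(bool)
    targets = targets.reshape(len(targets), -1).astype(bool)
    intersection = (preds & targets).sum(axis=1)
    denom = preds.sum(axis=1) + targets.sum(axis=1)
    scores = np.full(len(preds), empty_score, dtype=float)
    nonempty = denom > 0
    scores[nonempty] = 2.0 * intersection[nonempty] / denom[nonempty]
    return scores

preds = np.zeros((3, 4, 4))
targets = np.zeros((3, 4, 4))
preds[1, :2, :2] = targets[1, :2, :2] = 1   # image 1: perfect match
preds[2, 0, 0] = 1                          # image 2: false positive on an empty target
print(per_image_dice(preds, targets))       # [1. 1. 0.]
```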

Segmentation Metrics

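| Metric | Formula | Range |
|---|---|---|
| Pixel Accuracy | $\frac{\text{correct pixels}}{\text{total pixels}}$ | [0, 1] |
| IoU (per class) | $\frac{TP}{TP + FP + FN}$ | [0, 1] |
| mIoU | Mean IoU across classes | [0, 1] |
| Dice | $\frac{2TP}{2TP + FP + FN}$ | [0, 1] |
| Boundary F1 | F1 at boundary pixels | [0, 1] |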
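All of the count-based metrics in the table can be read off a per-class confusion matrix. A NumPy sketch (function names are illustrative; `ignore_index=255` mirrors the cross-entropy setup later on this page):

```python
import numpy as np

def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """Rows = true class, columns = predicted class; ignored pixels are skipped."""
    valid = target != ignore_index
    idx = num_classes * target[valid].astype(int) + pred[valid].astype(int)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def segmentation_metrics(pred, target, num_classes, ignore_index=255):
    cm = confusion_matrix(pred, target, num_classes, ignore_index)
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp        # predicted as class c, but truly another class
    fn = cm.sum(axis=1) - tp        # truly class c, but predicted as another class
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / (tp + fp + fn)
        dice = 2 * tp / (2 * tp + fp + fn)
    return {
        "pixel_acc": tp.sum() / cm.sum(),
        "iou": iou,
        "miou": np.nanmean(iou),    # classes absent from both pred and target (NaN) are skipped
        "dice": dice,
    }

target = np.array([[0, 0, 1], [1, 1, 2]])
pred = np.array([[0, 1, 1], [1, 2, 2]])
m = segmentation_metrics(pred, target, num_classes=3)
print(m["pixel_acc"], m["iou"], m["miou"])
```

In this toy example, 4 of 6 pixels are correct and every class has IoU 0.5, so mIoU is 0.5.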

Post-Processing for Segmentation

Connected Component Analysis

Remove small isolated predicted regions:

python
import numpy as np
from scipy import ndimage

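def remove_small_components(mask, min_size=100):
    """Return a copy of a binary mask with connected components smaller than min_size pixels removed."""
    mask = mask.copy()                            # don't modify the caller's array
    labeled, num_features = ndimage.label(mask)   # 4-connectivity by default in 2D
    sizes = ndimage.sum(mask > 0, labeled, range(1, num_features + 1))  # pixels per component
    small_components = np.where(sizes < min_size)[0] + 1                # component labels start at 1
    mask[np.isin(labeled, small_components)] = 0
    return mask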

Conditional Random Fields (CRF)

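A fully connected (dense) CRF refines segmentation boundaries by encouraging nearby pixels with similar color to take the same label; the network's per-pixel class probabilities serve as the unary term: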

python
# Using pydensecrf
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def crf_refine(image, probs, n_iters=5):
    """Apply dense CRF to refine segmentation probabilities."""
    h, w = image.shape[:2]
    n_classes = probs.shape[0]

    d = dcrf.DenseCRF2D(w, h, n_classes)
    unary = unary_from_softmax(probs)
    d.setUnaryEnergy(unary)

    # Pairwise: appearance (color) + smoothness
    d.addPairwiseGaussian(sxy=3, compat=3)
    # rgbim must be a C-contiguous HxWx3 uint8 array
    d.addPairwiseBilateral(sxy=80, srgb=13,
                           rgbim=np.ascontiguousarray(image, dtype=np.uint8),
                           compat=10)

    Q = d.inference(n_iters)
    return np.array(Q).reshape(n_classes, h, w)

SegFormer: Transformer-Based Segmentation

SegFormer (Xie et al., 2021) uses a hierarchical transformer encoder with a simple MLP decoder:

python
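import torch
import torch.nn.functional as F
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

checkpoint = "nvidia/segformer-b2-finetuned-ade-512-512"
processor = SegformerImageProcessor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg").convert("RGB")  # placeholder path; a PIL image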

# Inference
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Upsample logits to original image size
logits = F.interpolate(
    outputs.logits, size=image.size[::-1],
    mode='bilinear', align_corners=False
)
prediction = logits.argmax(dim=1).squeeze().numpy()

Architecture Comparison for Segmentation

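| Model | Backbone | Params | mIoU (ADE20K) | Speed |
|---|---|---|---|---|
| FCN | VGG-16 | 134M | 29.4 | Slow |
| U-Net | Custom | ~31M | N/A (medical) | Fast |
| DeepLabv3+ | ResNet-101 | 63M | 45.1 | Medium |
| SegFormer-B2 | MiT-B2 | 27.5M | 46.5 | Fast |
| Mask2Former | Swin-L | 216M | 57.7 | Slow |
| SAM | ViT-H | 641M | N/A (promptable) | Very slow |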

Data Augmentation for Segmentation

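Spatial transforms (crops, flips, rotations, elastic and grid distortions) must be applied identically to the image and its mask; otherwise the labels no longer line up with the pixels. Photometric transforms (brightness, contrast, normalization) should change only the image. Albumentations handles both when the image and mask are passed in the same call: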

python
import albumentations as A

transform = A.Compose([
    A.RandomCrop(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ElasticTransform(alpha=120, sigma=120 * 0.05, p=0.3),
    A.GridDistortion(p=0.3),
    A.RandomBrightnessContrast(p=0.3),
    # ImageNet statistics for RGB input; applied to the image only, never the mask
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

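# image: HxWx3 uint8 array, mask: HxW integer array (assumed already loaded)
# Pass both in one call so they receive identical spatial transforms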
augmented = transform(image=image, mask=mask)
aug_image = augmented['image']
aug_mask = augmented['mask']
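Mechanically, "the same spatial transform" means drawing the random parameters once and applying them to both arrays, while photometric changes touch only the image. A minimal NumPy sketch of that principle; the function name and the chosen transforms are illustrative assumptions, not Albumentations internals:

```python
import numpy as np

def paired_flip_rot90(image, mask, rng):
    """Apply one randomly drawn flip/rotation to BOTH arrays; brightness to the image only."""
    if rng.random() < 0.5:                  # a single coin flip shared by image and mask
        image, mask = image[:, ::-1], mask[:, ::-1]
    k = int(rng.integers(0, 4))             # a single rotation amount shared by both
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    image = image * rng.uniform(0.8, 1.2)   # photometric change: never applied to the mask
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

rng = np.random.default_rng(0)
image = np.random.default_rng(1).random((64, 64, 3))   # toy HxWx3 image
mask = (image[..., 0] > 0.5).astype(np.uint8)          # toy HxW mask
aug_image, aug_mask = paired_flip_rot90(image, mask, rng)
print(aug_image.shape, aug_mask.shape, aug_mask.dtype)
```

Because the mask is never rescaled or interpolated here, its labels stay discrete and keep their original dtype.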

Multi-Class Segmentation

For K classes, the model outputs K channels. Use cross-entropy loss per pixel:

python
# Model output: (batch, K, H, W)
# Target: (batch, H, W) with integer class labels

criterion = nn.CrossEntropyLoss(
    weight=class_weights,  # Handle imbalance
    ignore_index=255,      # Ignore unlabeled pixels
)

loss = criterion(output, target)
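The cross-entropy example relies on per-class weights. One simple way to derive them, an assumption here rather than the only option (median-frequency balancing is another common choice), is inverse pixel frequency normalized to a mean of 1 over the classes that appear. The function name is illustrative; convert the result with `torch.tensor(weights, dtype=torch.float32)` before passing it as `weight=`.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes, ignore_index=255):
    """Per-class weights proportional to 1 / pixel frequency, mean 1 over present classes."""
    valid = labels[labels != ignore_index].astype(int)       # drop unlabeled pixels
    counts = np.bincount(valid, minlength=num_classes).astype(float)
    freq = counts / counts.sum()
    weights = np.zeros(num_classes)
    present = freq > 0
    weights[present] = 1.0 / freq[present]                   # absent classes get weight 0
    return weights * (present.sum() / weights.sum())         # normalize to mean 1

labels = np.array([[0, 0, 0, 0],
                   [0, 0, 1, 1],
                   [255, 255, 255, 255]])   # 6 background pixels, 2 foreground, 4 ignored
print(inverse_frequency_weights(labels, num_classes=2))     # [0.5 1.5]
```

The rarer class gets 3x the weight because it has 3x fewer labeled pixels.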

Common Segmentation Pitfalls

| Issue | Symptom | Fix |
|---|---|---|
| Jagged edges | Boundaries not smooth | CRF post-processing, higher resolution |
| Missing small objects | Small structures not detected | Use FPN, increase resolution, weighted loss |
| Class imbalance | Background dominates | Dice loss, focal loss, class weights |
| Checkerboard artifacts | Grid pattern in output | Use bilinear upsampling followed by a convolution instead of transposed convolution |
| Train/val inconsistency | Val much worse than train | Ensure same preprocessing for both |
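Why transposed convolutions can produce checkerboards: when the kernel size is not divisible by the stride, neighboring output positions receive contributions from different numbers of input pixels (the "uneven overlap" analyzed by Odena et al., 2016). A 1-D NumPy sketch with an all-ones input and kernel shows the pattern; the function name is illustrative:

```python
import numpy as np

def conv_transpose_1d(x, k, stride):
    """Naive 1-D transposed convolution (no padding): each input value stamps
    a scaled copy of the kernel into the output, `stride` positions apart."""
    out = np.zeros((len(x) - 1) * stride + len(k))
    for i, v in enumerate(x):
        out[i * stride : i * stride + len(k)] += v * k
    return out

x = np.ones(4)
print(conv_transpose_1d(x, np.ones(3), stride=2))  # [1. 1. 2. 1. 2. 1. 2. 1. 1.]  uneven overlap
print(conv_transpose_1d(x, np.ones(2), stride=2))  # [1. 1. 1. 1. 1. 1. 1. 1.]  no overlap
print(np.repeat(x, 2))                             # nearest-neighbor 2x upsampling: uniform
```

The U-Net on this page uses `kernel_size=2, stride=2`, which tiles without overlap; the resize-then-convolve fix in the table replaces the transposed convolution with a fixed upsampling followed by an ordinary convolution.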


"What I cannot create, I do not understand." — Richard Feynman