Unverified: AI-generated content.

Model Optimization

Production models must be fast, small, and efficient. A 70B-parameter model is of little use in an interactive product if it takes 30 seconds to respond. This page covers the major optimization techniques: pruning (removing unnecessary weights), quantization (reducing numerical precision), knowledge distillation (training a smaller model to mimic a larger one), and deployment with ONNX and TensorRT, followed by benchmarking and serving.

Why Optimize?

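Illustrative gains (actual numbers depend on the model, hardware, and workload):

| Metric | Unoptimized | Optimized | Technique |
|---|---|---|---|
| Model size | 500 MB | 125 MB | INT8 quantization |
| Inference latency | 100 ms | 25 ms | TensorRT + INT8 |
| Memory usage | 8 GB | 2 GB | Quantization + pruning |
| Throughput | 10 req/s | 60 req/s | Batching + optimization |
| Mobile deployment | Impossible | Possible | Distillation + quantization |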

Pruning

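Remove unnecessary weights (set them to zero) to reduce model size and computation. The savings materialize only when the zeros are exploited, through sparse storage and sparse kernels or by physically removing pruned channels and layers.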

Unstructured Pruning

Remove individual weights based on magnitude:

$$W_{ij} = \begin{cases} W_{ij} & \text{if } |W_{ij}| > \tau \\ 0 & \text{otherwise} \end{cases}$$

where τ is the pruning threshold, often set to achieve a target sparsity (e.g., 90%).

python
import torch
import torch.nn.utils.prune as prune
import torchvision

model = torchvision.models.resnet50(weights='DEFAULT')

# Prune 30% of weights in each Conv2d layer
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.3)

# Check sparsity
total = 0
zero = 0
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        total += module.weight.nelement()
        zero += (module.weight == 0).sum().item()
print(f"Sparsity: {100 * zero / total:.1f}%")

# Make pruning permanent
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, 'weight')
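The thresholding rule above can also be written out directly. The following is a from-scratch NumPy sketch, not how torch.nn.utils.prune works internally: it picks τ as the magnitude of the k-th smallest weight so that a target fraction of weights is removed, then zeroes every weight with |W_ij| ≤ τ. The helper name magnitude_prune is illustrative.

```python
import numpy as np

def magnitude_prune(W, sparsity):
    """Zero every weight with |W_ij| <= tau, choosing tau to hit `sparsity`."""
    k = int(round(sparsity * W.size))  # number of weights to remove
    if k == 0:
        return W.copy(), 0.0
    tau = np.sort(np.abs(W), axis=None)[k - 1]  # k-th smallest magnitude
    W_pruned = np.where(np.abs(W) > tau, W, 0.0)
    return W_pruned, tau

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 64))
W_pruned, tau = magnitude_prune(W, sparsity=0.9)
print(f"tau = {tau:.3f}, sparsity = {np.mean(W_pruned == 0):.3f}")
```

If several weights tie at τ, this rule removes slightly more than k of them.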

Structured Pruning

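Remove entire channels, filters, or layers. This is more hardware-friendly: what remains is a smaller dense tensor, so no sparse-matrix support is needed. Note that prune.ln_structured only zeroes out whole filters; the layer keeps its original shape (and cost) until it is rebuilt with fewer channels: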

python
# Zero out 20% of the output channels (dim=0) of each Conv2d, ranked by L2 norm (n=2)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=0.2, n=2, dim=0)
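To actually shrink the network, the surviving filters have to be copied into a narrower layer, and the next layer's input channels trimmed to match. A minimal NumPy sketch of that bookkeeping, assuming plain conv weights laid out as (out_channels, in_channels, kh, kw) and a conv that directly follows the pruned one (no BatchNorm or residual connection in between). prune_conv_filters is a hypothetical helper.

```python
import numpy as np

def prune_conv_filters(w, b, w_next, amount):
    """Physically remove the `amount` fraction of filters with the smallest L2 norm.

    w:      weight of the pruned conv, shape (out_ch, in_ch, kh, kw)
    b:      its bias, shape (out_ch,)
    w_next: weight of the following conv, shape (next_out, out_ch, kh, kw)
    """
    norms = np.sqrt((w ** 2).sum(axis=(1, 2, 3)))  # L2 norm of each output filter
    n_keep = w.shape[0] - int(round(amount * w.shape[0]))
    if n_keep <= 0:
        raise ValueError("amount would remove every filter")
    keep = np.sort(np.argsort(norms)[-n_keep:])    # strongest filters, original order
    # Slice the output channels here and the matching input channels downstream
    return w[keep], b[keep], w_next[:, keep], keep

rng = np.random.default_rng(0)
w = rng.normal(size=(32, 16, 3, 3))
b = rng.normal(size=32)
w_next = rng.normal(size=(64, 32, 3, 3))
w2, b2, w_next2, keep = prune_conv_filters(w, b, w_next, amount=0.2)
print(w2.shape, w_next2.shape)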

Iterative Magnitude Pruning (IMP)

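The Lottery Ticket Hypothesis (Frankle and Carbin, 2019): dense networks contain sparse subnetworks that, trained in isolation from their original initialization, reach the same accuracy as the full network. Iterative magnitude pruning finds them:

  1. Train the full network
  2. Prune the smallest 20% of the remaining weights
  3. Reset the surviving weights to their initial values
  4. Retrain
  5. Repeat steps 2-4 until the target sparsity is reached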
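The five steps can be traced on a toy problem. In this NumPy sketch, "training" is plain gradient descent on a least-squares model, a stand-in assumption for real network training; a mask keeps pruned weights at zero, and each round rewinds the survivors to the saved initialization before retraining. All function names are illustrative.

```python
import numpy as np

def train(w_init, mask, X, y, lr=0.3, steps=500):
    """Stand-in for training: gradient descent on least squares, pruned weights held at 0."""
    w = w_init * mask
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = (w - lr * grad) * mask
    return w

def iterative_magnitude_pruning(X, y, rounds=5, prune_frac=0.2, seed=0):
    rng = np.random.default_rng(seed)
    w_init = rng.normal(scale=0.1, size=X.shape[1])  # saved initialization
    mask = np.ones_like(w_init)
    for _ in range(rounds):
        w = train(w_init, mask, X, y)                # steps 1/4: (re)train from the init
        alive = np.flatnonzero(mask)
        k = int(round(prune_frac * alive.size))      # step 2: 20% of the *remaining* weights
        if k == 0:
            break
        smallest = alive[np.argsort(np.abs(w[alive]))[:k]]
        mask[smallest] = 0.0
        # Step 3 is implicit: the next train() call restarts from w_init * mask
    return train(w_init, mask, X, y), mask

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 50))
w_true = np.zeros(50)
w_true[:5] = [3.0, -2.0, 1.5, -1.0, 2.5]  # only 5 weights matter
y = X @ w_true
w_final, mask = iterative_magnitude_pruning(X, y)
print(f"{int(mask.sum())} of 50 weights survive")
```

Weights that never mattered are pruned first, so the five informative ones survive every round. In a real network the same loop runs per layer (or globally) over full training runs.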

Quantization

Reduce the precision of weights and/or activations from FP32 to INT8 or lower.

Quantization Fundamentals

Map floating-point values to integers:

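$$q = \operatorname{clamp}\left(\operatorname{round}\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right)$$

where s (scale) and z (zero-point) are calibration parameters:

$$s = \frac{x_{\max} - x_{\min}}{q_{\max} - q_{\min}}, \qquad z = q_{\min} - \operatorname{round}\left(\frac{x_{\min}}{s}\right)$$

The zero-point maps $x_{\min}$ to $q_{\min}$. For an unsigned range $[0, 255]$, $q_{\min} = 0$ and this reduces to $z = \operatorname{round}(-x_{\min}/s)$.

Dequantize back to float:

$$\hat{x} = s\,(q - z)$$

Worked Example — INT8 Quantization of a Small Weight Matrix

Input: FP32 weight matrix:

$$W = \begin{bmatrix} 0.35 & -1.20 & 0.80 \\ -0.50 & 1.50 & 0.05 \end{bmatrix}$$

INT8 range: $q_{\min} = -128$, $q_{\max} = 127$

Step 1: Find min/max of weights:

$$x_{\min} = -1.20, \qquad x_{\max} = 1.50$$

Step 2: Compute scale and zero-point:

$$s = \frac{1.50 - (-1.20)}{127 - (-128)} = \frac{2.70}{255} \approx 0.01059$$

$$z = -128 - \operatorname{round}\left(\frac{-1.20}{0.01059}\right) = -128 - \operatorname{round}(-113.3) = -128 + 113 = -15$$

Step 3: Quantize each weight, $q = \operatorname{clamp}(\operatorname{round}(x/s) + z, -128, 127)$:

| $W_{ij}$ | $\operatorname{round}(W_{ij}/s)$ | $+\,z$ (z = -15) | Quantized (after clamp) |
|---|---|---|---|
| 0.35 | 33 | 18 | 18 |
| -1.20 | -113 | -128 | -128 |
| 0.80 | 76 | 61 | 61 |
| -0.50 | -47 | -62 | -62 |
| 1.50 | 142 | 127 | 127 |
| 0.05 | 5 | -10 | -10 |

No value needs clamping: the zero-point places $x_{\min}$ exactly at $q_{\min}$ and $x_{\max}$ at $q_{\max}$.

Step 4: Dequantize to verify accuracy, $\hat{x} = s(q - z)$ (using the unrounded $s = 2.70/255$):

| Quantized $q$ | $q - z$ | Dequantized $s(q - z)$ | Original | Error |
|---|---|---|---|---|
| 18 | 33 | 0.349 | 0.35 | 0.0006 |
| -128 | -113 | -1.196 | -1.20 | 0.0035 |
| 61 | 76 | 0.805 | 0.80 | 0.0047 |
| -62 | -47 | -0.498 | -0.50 | 0.0024 |
| 127 | 142 | 1.504 | 1.50 | 0.0035 |
| -10 | 5 | 0.053 | 0.05 | 0.0029 |

Result: Every error is below $s/2 \approx 0.0053$, the worst case for round-to-nearest inside the clipping range. A wrong zero-point breaks this: using the unsigned-range formula $z = \operatorname{round}(-x_{\min}/s) = 113$ with the signed range would push 0.35, 0.80, and 1.50 above 127, so all three would clamp to 127 and dequantize to about 0.148, with errors up to 1.35. In real weight tensors the main source of error is outliers: a few large values stretch $[x_{\min}, x_{\max}]$, inflate $s$, and leave few quantization levels for the many small weights. Outlier-aware methods like AWQ address this by scaling salient weights before quantization to reduce their relative rounding error.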
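The arithmetic of this worked example fits in a few lines of NumPy. This sketch implements per-tensor asymmetric quantization (the helper names are illustrative, not a library API). Note the zero-point: for the signed range [-128, 127] it is z = q_min - round(x_min/s), which maps x_min onto q_min, so every value of the example matrix fits without clamping.

```python
import numpy as np

def quant_params(x_min, x_max, q_min=-128, q_max=127):
    """Scale and zero-point that map [x_min, x_max] onto [q_min, q_max]."""
    s = (x_max - x_min) / (q_max - q_min)
    z = int(q_min - np.round(x_min / s))
    return s, z

def quantize(x, s, z, q_min=-128, q_max=127):
    q = np.clip(np.round(x / s) + z, q_min, q_max)
    return q.astype(np.int8)

def dequantize(q, s, z):
    return s * (q.astype(np.float64) - z)

W = np.array([[0.35, -1.20, 0.80],
              [-0.50, 1.50, 0.05]])
s, z = quant_params(W.min(), W.max())
q = quantize(W, s, z)
W_hat = dequantize(q, s, z)
print(f"s = {s:.5f}, z = {z}")
print(q)
print(f"max error = {np.abs(W - W_hat).max():.4f} (bound s/2 = {s / 2:.4f})")
```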

Precision Comparison

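| Precision | Bits | Range | Model size (7B params, weights only) |
|---|---|---|---|
| FP32 | 32 | $\pm 3.4 \times 10^{38}$ | 28 GB |
| FP16 | 16 | $\pm 6.5 \times 10^{4}$ | 14 GB |
| BF16 | 16 | $\pm 3.4 \times 10^{38}$ (7 mantissa bits vs 10 in FP16, so less precision) | 14 GB |
| INT8 | 8 | -128 to 127 | 7 GB |
| INT4 | 4 | -8 to 7 | 3.5 GB |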

Post-Training Quantization (PTQ)

Quantize a trained model without retraining:

python
import os
import torch

# Dynamic quantization: Linear weights are converted to INT8 ahead of time,
# activations are quantized on the fly at runtime. Runs on CPU and helps most
# on Linear/LSTM-heavy models (e.g. transformers); in a CNN such as ResNet-50
# only the final fully connected layer is quantized.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Which layer types to quantize
    dtype=torch.qint8,
)

# Compare on-disk sizes
torch.save(model.state_dict(), 'fp32_model.pth')
torch.save(quantized_model.state_dict(), 'int8_model.pth')
fp32_size = os.path.getsize('fp32_model.pth') / 1e6
int8_size = os.path.getsize('int8_model.pth') / 1e6
print(f"FP32: {fp32_size:.1f} MB, INT8: {int8_size:.1f} MB")
print(f"Compression: {fp32_size / int8_size:.1f}x")

Static Quantization (Better Accuracy)

Calibrate activation ranges on representative data:

python
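import torch
import torchvision

# Eager-mode static quantization needs QuantStub/DeQuantStub in the model;
# torchvision ships quantization-ready variants of common architectures.
model = torchvision.models.quantization.resnet50(weights='DEFAULT', quantize=False)

# 1. Prepare model: fuse Conv+BN+ReLU, then attach observers
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig('x86')
model_prepared = torch.quantization.prepare(model)

# 2. Calibrate: observers record activation ranges on representative data
with torch.no_grad():
    for inputs, _ in calibration_loader:  # calibration_loader: your own DataLoader
        model_prepared(inputs)

# 3. Convert: swap float modules for INT8 kernels
model_quantized = torch.quantization.convert(model_prepared)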

Quantization-Aware Training (QAT)

Simulate quantization during training to learn robust weights:

python
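# Assumes a freshly loaded quantization-ready model (QuantStub/DeQuantStub), e.g.
# torchvision.models.quantization.resnet50(weights='DEFAULT', quantize=False);
# train_loader, criterion and num_epochs are defined elsewhere.
model.train()
model.fuse_model(is_qat=True)  # QAT fusion keeps BatchNorm trainable
model.qconfig = torch.quantization.get_default_qat_qconfig('x86')
model_qat = torch.quantization.prepare_qat(model)  # returns a copy by default

# Build the optimizer on model_qat, not model: prepare_qat copied the weights
optimizer = torch.optim.SGD(model_qat.parameters(), lr=1e-4, momentum=0.9)

# Train with fake quantization (simulates INT8 rounding during forward pass)
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model_qat(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

# Convert to quantized model
model_quantized = torch.quantization.convert(model_qat.eval())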
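What "fake quantization" does can be shown without PyTorch. In the forward pass a value is quantized and immediately dequantized, so it stays a float but lies on the INT8 grid. Because round() has zero gradient almost everywhere, the backward pass typically uses a straight-through estimator (STE): gradients pass through unchanged inside the representable range and are zeroed where values were clamped. PyTorch's fake-quantize modules follow broadly this pattern; the NumPy helpers below are a hypothetical illustration, not their implementation.

```python
import numpy as np

def fake_quant(x, s, z, q_min=-128, q_max=127):
    """Forward pass: quantize, then dequantize; output is float but on the INT8 grid."""
    q = np.clip(np.round(x / s) + z, q_min, q_max)
    return s * (q - z)

def fake_quant_grad(x, upstream, s, z, q_min=-128, q_max=127):
    """Straight-through estimator: treat round() as identity inside the
    representable range and block gradients where values were clamped."""
    inside = (x >= s * (q_min - z)) & (x <= s * (q_max - z))
    return upstream * inside

x = np.array([0.04, 0.26, 20.0, -13.0])
s, z = 0.1, 0
y = fake_quant(x, s, z)
g = fake_quant_grad(x, np.ones_like(x), s, z)
print(y)
print(g)
```

The clamped inputs (20.0 and -13.0) receive no gradient, which pushes training toward weights and activations that fit the quantization range.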

GPTQ: Post-Training Quantization for LLMs

GPTQ (Frantar et al., 2023) quantizes LLMs to 4-bit with minimal accuracy loss by solving a layer-wise quantization problem:

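$$\arg\min_{\hat{W}} \left\| W X - \hat{W} X \right\|_2^2$$

where $X$ holds the layer's inputs on a small calibration set. GPTQ quantizes the weight matrix one column (input dimension) at a time; after each column it updates the not-yet-quantized columns to compensate for that column's quantization error, using the inverse of the Hessian $H = 2XX^\top$ of this objective.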

python
# Requires the optimum package and a GPTQ backend (e.g. auto-gptq or gptqmodel)
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

quantization_config = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf"),
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=quantization_config,
    device_map="auto",
)
# 7B model: ~28 GB in FP32 (~14 GB in FP16) → ~4 GB at 4-bit
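The column-by-column error compensation can be sketched in NumPy for a single linear layer. This is a simplified reading of the idea under stated assumptions: a symmetric 4-bit grid with one scale per output row, plain matrix inversion instead of the paper's Cholesky-based formulation, and no lazy batched updates or grouping. gptq_quantize and rtn_quantize are illustrative names, not a library API.

```python
import numpy as np

def rtn_quantize(w, scale, bits=4):
    """Round-to-nearest onto a symmetric signed grid (one scale per row)."""
    qmax = 2 ** (bits - 1) - 1
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale

def gptq_quantize(W, X, bits=4, damp=0.01):
    """Quantize W (rows x d) one column at a time, pushing each column's
    error onto the not-yet-quantized columns via the inverse Hessian."""
    W = W.astype(np.float64).copy()
    d = W.shape[1]
    scale = np.abs(W).max(axis=1, keepdims=True) / (2 ** (bits - 1) - 1)
    H = 2 * X @ X.T                                # Hessian of ||WX - W_hat X||^2
    H += damp * np.mean(np.diag(H)) * np.eye(d)    # dampening for numerical stability
    Hinv = np.linalg.inv(H)
    Q = np.zeros_like(W)
    for j in range(d):
        Q[:, j] = rtn_quantize(W[:, j:j + 1], scale, bits)[:, 0]
        err = (W[:, j] - Q[:, j]) / Hinv[j, j]
        W[:, j + 1:] -= np.outer(err, Hinv[j, j + 1:])         # compensate remaining columns
        Hinv -= np.outer(Hinv[:, j], Hinv[j, :]) / Hinv[j, j]  # drop column j from H^-1
    return Q

rng = np.random.default_rng(0)
d, n, rows = 64, 512, 16
latent = rng.normal(size=(8, n))
X = rng.normal(size=(d, 8)) @ latent + 0.5 * rng.normal(size=(d, n))  # correlated inputs
W = rng.normal(size=(rows, d))

scale = np.abs(W).max(axis=1, keepdims=True) / 7
err_rtn = np.linalg.norm(W @ X - rtn_quantize(W, scale) @ X)
err_gptq = np.linalg.norm(W @ X - gptq_quantize(W, X) @ X)
print(f"output error  RTN: {err_rtn:.1f}  GPTQ-style: {err_gptq:.1f}")
```

Both versions use the same 4-bit grid; the compensation step is what lowers the layer's output error when the input features are correlated.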

AWQ: Activation-aware Weight Quantization

AWQ (Lin et al., 2024) observes that not all weights are equally important. Weights corresponding to large activation channels are more important:

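$$\text{Importance}(w_j) \propto \mathbb{E}\left[|X_j|\right]$$

AWQ multiplies each salient input channel's weights by a factor $s > 1$ before quantization and divides the matching activations by $s$ (folded into the preceding operation). The full-precision output is unchanged, but the salient weights suffer a smaller relative rounding error:

$$Q(w \cdot s) \cdot \frac{x}{s} \approx w \cdot x$$

The per-channel scales are chosen as $s = \mathbb{E}[|X|]^{\alpha}$, with $\alpha \in [0, 1]$ grid-searched per layer to minimize the output error.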
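A toy NumPy version of this search, for intuition only: it uses symmetric 4-bit round-to-nearest with one scale per output row (AWQ itself quantizes in groups along the input dimension), measures E[|X_j|] per input channel, and grid-searches α. Since α = 0 gives s = 1 (plain round-to-nearest), the search can only match or beat the unscaled baseline on the data it is tuned on. All names are hypothetical.

```python
import numpy as np

def quantize_rows(W, bits=4):
    """Symmetric round-to-nearest with one scale per output row."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(W).max(axis=1, keepdims=True) / qmax
    return np.clip(np.round(W / scale), -qmax - 1, qmax) * scale

def awq_search(W, X, bits=4, n_grid=20):
    """Grid-search alpha in s = E[|X_j|]**alpha; return (error, alpha, W_q)."""
    act_mag = np.abs(X).mean(axis=1)         # E[|X_j|] for each input channel j
    ref = W @ X
    best = (np.inf, None, None)
    for alpha in np.linspace(0.0, 1.0, n_grid + 1):
        s = act_mag ** alpha
        s = s / np.sqrt(s.max() * s.min())   # keep scales centred around 1
        # Q(W diag(s)) diag(s)^-1, equivalent to feeding x / s into Q(W diag(s))
        W_q = quantize_rows(W * s, bits) / s
        err = np.linalg.norm(ref - W_q @ X)
        if err < best[0]:
            best = (err, alpha, W_q)
    return best

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 256))
X[:4] *= 20.0                                # a few salient, high-activation channels
W = rng.normal(size=(32, 64))

err_rtn = np.linalg.norm(W @ X - quantize_rows(W) @ X)
err_awq, alpha, _ = awq_search(W, X)
print(f"RTN error: {err_rtn:.1f}  AWQ-style error: {err_awq:.1f} (alpha = {alpha:.2f})")
```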

TurboQuant: KV Cache Compression (Google, ICLR 2026)

While GPTQ and AWQ compress model weights, TurboQuant targets a different bottleneck: the KV cache that grows linearly with sequence length during inference. For long-context workloads (32K+ tokens), KV cache can consume more memory than the model weights themselves.

Paper: "TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate" — Zandieh, Daliri, Hadian, Mirrokni (Google Research & DeepMind, arXiv:2504.19874)

The two-stage pipeline:

┌──────────────┐    ┌──────────────┐    ┌─────────────────┐
│  KV Vector   │ -> │  PolarQuant  │ -> │  QJL Error      │
│  (FP16/BF16) │    │  (3-4 bits)  │    │  Correction     │
│              │    │              │    │  (1-bit signs)  │
└──────────────┘    └──────────────┘    └─────────────────┘

Stage 1 — PolarQuant:

  1. Apply a random orthogonal rotation to spread energy uniformly across coordinates
  2. Convert pairs of Cartesian coordinates to polar form (radius + angle)
  3. Recursively pair up the radii and convert again, until one radius and d - 1 angles remain
  4. Quantize the angles, whose distribution after rotation is predictable, with precomputed optimal scalar quantizers
  5. No per-block normalization constants need to be stored (unlike standard quantizers)
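Steps 1-3 are a lossless change of coordinates, which a short NumPy sketch can make concrete. It assumes the dimension d is a power of two and, as a simplification, quantizes each level's angles with a uniform midpoint quantizer instead of the optimal scalar quantizers described above; it illustrates the transform, not the authors' implementation. Function names are hypothetical.

```python
import numpy as np

def random_rotation(d, rng):
    """Random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(d, d)))
    return q * np.sign(np.diag(r))

def polar_encode(x):
    """Recursive polar transform: d = 2**L values -> 1 radius + (d - 1) angles."""
    angles, r = [], x
    while r.size > 1:
        a, b = r[0::2], r[1::2]
        angles.append(np.arctan2(b, a))  # level 1 in [-pi, pi]; deeper levels in [0, pi/2]
        r = np.hypot(a, b)
    return r[0], angles

def polar_decode(radius, angles):
    r = np.array([radius])
    for theta in reversed(angles):
        out = np.empty(2 * r.size)
        out[0::2], out[1::2] = r * np.cos(theta), r * np.sin(theta)
        r = out
    return r

def quantize_angles(angles, bits):
    """Uniform midpoint quantizer per level (simplification, not the optimal quantizer)."""
    n = 2 ** bits
    out = []
    for level, theta in enumerate(angles):
        lo, hi = (-np.pi, np.pi) if level == 0 else (0.0, np.pi / 2)
        step = (hi - lo) / n
        idx = np.clip(np.floor((theta - lo) / step), 0, n - 1)
        out.append(lo + (idx + 0.5) * step)
    return out

rng = np.random.default_rng(0)
d = 64
x = rng.normal(size=d)
R = random_rotation(d, rng)
radius, angles = polar_encode(R @ x)

x_exact = R.T @ polar_decode(radius, angles)  # no quantization: exact round trip
x_hat = R.T @ polar_decode(radius, quantize_angles(angles, bits=4))
print(sum(a.size for a in angles), "angles")
print(f"relative error at 4 bits/angle: {np.linalg.norm(x - x_hat) / np.linalg.norm(x):.3f}")
```

Only one radius is kept at full precision; everything else is a bounded angle, which is why no per-block scale has to be stored.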

Stage 2 — QJL (Quantized Johnson-Lindenstrauss) Error Correction:

  1. Project the residual quantization error to a lower-dimensional space
  2. Reduce each value to a single sign bit (+1 or -1)
  3. Use a hybrid estimator: high-precision query + low-precision cache
  4. Eliminates systematic bias in attention score calculations at negligible cost
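The 1-bit estimator behind steps 2-3 can be sketched on its own, following the Quantized Johnson-Lindenstrauss idea: store only the signs of a Gaussian random projection of a key plus its norm, keep the query in full precision, and rescale by sqrt(π/2) so the inner-product estimate is unbiased. This NumPy sketch applies it to a whole vector rather than to a residual, and the projection size m is set large only so that a single estimate lands visibly close to the true value (the estimator's standard deviation shrinks as 1/sqrt(m)).

```python
import numpy as np

def qjl_encode(k, S):
    """Keep one sign bit per projected coordinate, plus the key's norm."""
    return np.sign(S @ k), np.linalg.norm(k)

def qjl_inner_product(q, signs, k_norm, S):
    """Unbiased estimate of <q, k> from a full-precision query and a 1-bit key."""
    m = S.shape[0]
    return np.sqrt(np.pi / 2) / m * k_norm * np.dot(S @ q, signs)

rng = np.random.default_rng(0)
d, m = 64, 20_000
S = rng.normal(size=(m, d))  # shared Gaussian projection

k = rng.normal(size=d)
k /= np.linalg.norm(k)
u = rng.normal(size=d)
u -= (u @ k) * k             # make u orthogonal to k
u /= np.linalg.norm(u)
q = 0.6 * k + 0.8 * u        # unit query with <q, k> = 0.6

signs, k_norm = qjl_encode(k, S)
estimate = qjl_inner_product(q, signs, k_norm, S)
print(f"true: {q @ k:.3f}  estimate: {estimate:.3f}")
```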

Why it matters (figures as reported by the authors):

  • Data-oblivious: no calibration data and no fine-tuning, so it applies to any transformer
  • 3-bit KV cache quantization with no measured accuracy loss on Gemma, Mistral, Llama 3.1, and Ministral
  • 6x KV cache memory reduction and up to 8x attention speedup on H100 GPUs
  • Complements GPTQ/AWQ: weight quantization and KV cache compression can be stacked
  • 100% retrieval accuracy on Needle-in-a-Haystack up to 104K tokens at 4x compression

Results across benchmarks:

| Bits | Memory reduction | Quality impact | Benchmarks |
|---|---|---|---|
| 3.5 | ~5x | Absolute quality neutrality | LongBench, RULER, ZeroSCROLLS |
| 3.0 | ~5.5x | Negligible degradation | Needle-in-a-Haystack: 100% |
| 2.5 | ~6.5x | Marginal degradation | L-Eval: within 1-2% |

When to use TurboQuant vs weight quantization

  • Short contexts (< 8K tokens): weight quantization (GPTQ/AWQ) gives most of the savings
  • Long contexts (32K+ tokens) or large batches: the KV cache often dominates memory, so KV cache compression such as TurboQuant matters most
  • Best of both: GPTQ/AWQ for the weights plus TurboQuant for the KV cache gives the largest total reduction
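A quick way to tell which regime you are in is to compute the KV cache size: every layer stores one key and one value vector per KV head for every token. The sketch below assumes standard multi-head attention with Llama-2-7B's shape (32 layers, 32 KV heads, head dimension 128) at batch size 1; models with grouped-query attention have fewer KV heads and a proportionally smaller cache. kv_cache_bytes is an illustrative helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bits=16):
    """Keys + values for every layer, KV head, and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bits // 8

GiB = 2 ** 30
for seq_len in (4_096, 32_768, 131_072):
    fp16 = kv_cache_bytes(32, 32, 128, seq_len)
    three_bit = kv_cache_bytes(32, 32, 128, seq_len, bits=3)
    print(f"{seq_len:>7} tokens: {fp16 / GiB:5.1f} GiB at FP16, {three_bit / GiB:5.1f} GiB at 3 bits")
```

For comparison, Llama-2-7B's roughly 6.7B parameters take about 13.5 GB in FP16, so at 32K tokens a single sequence's FP16 KV cache (16 GiB) already exceeds the weights.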

Quantization Comparison for LLMs

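| Method | Target | Bits | Calibration | Speed | Quality |
|---|---|---|---|---|---|
| GPTQ | Weights | 4 | Requires data | Fast inference | Good |
| AWQ | Weights | 4 | Requires data | Fastest inference | Best |
| TurboQuant | KV cache | 3-4 | None (data-oblivious) | Up to 8x attention speedup (reported) | Near-lossless |
| GGUF (llama.cpp) | Weights | 2-8 | No calibration | Good (CPU) | Varies |
| bitsandbytes | Weights | 4/8 | None | Training + inference | Good |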

Knowledge Distillation

Train a small "student" model to mimic a large "teacher" model.

The Distillation Loss

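$$\mathcal{L} = \alpha\,\mathcal{L}_{\text{hard}} + (1-\alpha)\,T^2\,D_{\mathrm{KL}}\left(\operatorname{softmax}\left(\frac{z_t}{T}\right) \,\Big\|\, \operatorname{softmax}\left(\frac{z_s}{T}\right)\right)$$
  • $\mathcal{L}_{\text{hard}}$: standard cross-entropy with true labels
  • $z_s, z_t$: student and teacher logits
  • $T$: temperature (typically 3-20)
  • $\alpha$: weight on the hard loss (typically 0.1-0.5)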

Why temperature? Softmax with a high T produces softer probabilities that reveal how the teacher relates the classes. A teacher might say "this 7 looks a bit like a 1 and a 9"; this "dark knowledge" helps the student learn.

Worked Example — Knowledge Distillation Temperature Effect

Setup: Teacher logits for a digit "7": $z_t = [0.1, 1.2, 0.3, 0.5, 0.2, 0.1, 0.4, 8.0, 0.3, 1.5]$ (classes 0-9)

T=1 (standard softmax):

  • P(class 7) ≈ 0.994, P(class 9) ≈ 0.0015, P(class 1) ≈ 0.0011
  • Effectively a hard target: almost all probability mass is on 7

T=5 (soft softmax, divide logits by 5 first):

  • Scaled logits: [0.02, 0.24, 0.06, 0.10, 0.04, 0.02, 0.08, 1.60, 0.06, 0.30]
  • P(class 7) ≈ 0.331, P(class 9) ≈ 0.090, P(class 1) ≈ 0.085

Result at T=5: The softened distribution reveals that the teacher thinks this "7" is somewhat similar to a "9" (9.0%) and a "1" (8.5%), both visually similar digits, while a "0" gets only 6.8%. This "dark knowledge" teaches the student about inter-class relationships, not just the correct label. The gradients of the soft term scale roughly as $1/T^2$, so the $T^2$ factor in the loss keeps their magnitude comparable to the hard loss.
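The probabilities in this example can be recomputed with a temperature-scaled softmax. This NumPy sketch subtracts the maximum scaled logit before exponentiating, a standard trick for numerical stability that does not change the result; softmax_T is an illustrative name.

```python
import numpy as np

def softmax_T(z, T=1.0):
    """Softmax of logits z at temperature T."""
    scaled = np.asarray(z, dtype=np.float64) / T
    e = np.exp(scaled - scaled.max())
    return e / e.sum()

z_t = [0.1, 1.2, 0.3, 0.5, 0.2, 0.1, 0.4, 8.0, 0.3, 1.5]
for T in (1, 5):
    p = softmax_T(z_t, T)
    print(f"T={T}: P(7)={p[7]:.3f}  P(9)={p[9]:.4f}  P(1)={p[1]:.4f}  P(0)={p[0]:.4f}")
```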

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.3):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        # Hard loss (true labels)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Soft loss (teacher's knowledge)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

        return self.alpha * hard_loss + (1 - self.alpha) * self.T**2 * soft_loss

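# Training (load_teacher_model / create_student_model are placeholders for your own code)
teacher = load_teacher_model()  # Large model
teacher.eval()
student = create_student_model()  # Small model
student.train()
criterion = DistillationLoss(temperature=4.0, alpha=0.3)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for inputs, labels in train_loader:
    with torch.no_grad():
        teacher_logits = teacher(inputs)  # Teacher stays frozen

    student_logits = student(inputs)
    loss = criterion(student_logits, teacher_logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()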

ONNX Export

ONNX (Open Neural Network Exchange) is an open, framework-independent model format that many runtimes and compilers (ONNX Runtime, TensorRT, OpenVINO) can load:

python
import torch

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
    opset_version=17,
)

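# Verify the exported graph is well-formed
import numpy as np
import onnx
model_onnx = onnx.load('model.onnx')
onnx.checker.check_model(model_onnx)

# Run with ONNX Runtime (CPU) and compare against PyTorch
import onnxruntime as ort
session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
result = session.run(None, {'input': dummy_input.numpy()})
with torch.no_grad():
    expected = model(dummy_input).numpy()
np.testing.assert_allclose(expected, result[0], rtol=1e-3, atol=1e-5)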

TensorRT Optimization

NVIDIA TensorRT compiles models into GPU-specific engines, applying layer and kernel fusion, reduced-precision execution (FP16, or INT8 with calibration), and kernel auto-tuning for the target GPU:

python
import torch
import torch_tensorrt
import torchvision

model = torchvision.models.resnet50(weights='DEFAULT').eval().cuda()

# Compile with TensorRT
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(
        shape=[1, 3, 224, 224],
        dtype=torch.float16,
    )],
    enabled_precisions={torch.float16},
    workspace_size=1 << 30,  # 1 GB
)

# Benchmark
import time
input_tensor = torch.randn(1, 3, 224, 224).half().cuda()
n_iters = 1000

with torch.no_grad():
    # Warmup (first calls include lazy initialization)
    for _ in range(50):
        trt_model(input_tensor)
    torch.cuda.synchronize()

    # Timed runs; synchronize so all queued GPU work is counted
    start = time.perf_counter()
    for _ in range(n_iters):
        trt_model(input_tensor)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
print(f"TensorRT: {elapsed / n_iters * 1e3:.2f} ms per inference")

Mobile Deployment

ExecuTorch (successor to PyTorch Mobile)

python
import torch
from executorch.exir import to_edge

model.eval()
example_input = (torch.randn(1, 3, 224, 224),)

# Export
edge_program = to_edge(torch.export.export(model, example_input))
et_program = edge_program.to_executorch()

# Save
with open('model.pte', 'wb') as f:
    f.write(et_program.buffer)


Benchmarking and Profiling

Measuring Inference Latency

python
import torch
import time

def benchmark_model(model, input_shape, device='cuda', n_warmup=50, n_runs=200):
    """Benchmark model inference latency and throughput on one device."""
    model = model.to(device).eval()
    x = torch.randn(*input_shape, device=device)

    def sync():
        # GPU kernels run asynchronously; wait for them before reading the clock
        if torch.device(device).type == 'cuda':
            torch.cuda.synchronize()

    # Warmup
    with torch.no_grad():
        for _ in range(n_warmup):
            model(x)
    sync()

    # Benchmark
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            model(x)
    sync()
    elapsed = time.perf_counter() - start

    latency_ms = elapsed / n_runs * 1000
    throughput = n_runs * input_shape[0] / elapsed  # samples per second
    print(f"Latency: {latency_ms:.2f} ms")
    print(f"Throughput: {throughput:.0f} samples/s")
    return latency_ms

# Compare FP32 vs INT8 on the same device (dynamic quantization runs on CPU)
fp32_latency = benchmark_model(model, (1, 3, 224, 224), device='cpu')
int8_latency = benchmark_model(quantized_model, (1, 3, 224, 224), device='cpu')
print(f"Speedup: {fp32_latency / int8_latency:.2f}x")

PyTorch Profiler

python
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    with record_function("model_inference"):
        model(input_tensor)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# Lists the operators (e.g. aten::conv2d) with the highest total CUDA time, plus their memory usage

Serving Architectures

Single Model Serving

python
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import base64
from io import BytesIO
from PIL import Image

app = FastAPI()
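# load_optimized_model, preprocess and CLASS_NAMES are placeholders for your own code.
# The model must be callable on a torch tensor (e.g. TorchScript); an ONNX model
# would be run through onnxruntime instead.
model = load_optimized_model()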

class PredictRequest(BaseModel):
    image_base64: str

@app.post("/predict")
def predict(req: PredictRequest):
    # Decode image
    image_bytes = base64.b64decode(req.image_base64)
    image = Image.open(BytesIO(image_bytes)).convert('RGB')

    # Preprocess
    tensor = preprocess(image).unsqueeze(0)

    # Inference
    with torch.no_grad():
        output = model(tensor)
        probs = torch.softmax(output, dim=1)
        class_idx = probs.argmax().item()
        confidence = probs.max().item()

    return {
        "class": CLASS_NAMES[class_idx],
        "confidence": confidence,
    }

Batched Inference

Accumulate requests and process as a batch for higher throughput:

python
import asyncio
from collections import deque

import torch

class BatchPredictor:
    def __init__(self, model, max_batch=32, max_wait_ms=50):
        self.model = model
        self.max_batch = max_batch
        self.max_wait = max_wait_ms / 1000
        self.queue = deque()

    async def predict(self, input_tensor):
        future = asyncio.get_running_loop().create_future()
        self.queue.append((input_tensor, future))

        if len(self.queue) >= self.max_batch:
            self._process_batch()
        else:
            await asyncio.sleep(self.max_wait)
            if not future.done():
                self._process_batch()

        return await future

    def _process_batch(self):
        batch_items = []
        while self.queue and len(batch_items) < self.max_batch:
            batch_items.append(self.queue.popleft())

        inputs = torch.stack([item[0] for item in batch_items])
        with torch.no_grad():
            outputs = self.model(inputs)

        for i, (_, future) in enumerate(batch_items):
            if not future.done():
                future.set_result(outputs[i])

Optimization Decision Matrix

| Constraint | Technique | Expected improvement |
|---|---|---|
| Model too large for deployment | INT8 quantization | 4x smaller, ~2x faster |
| Latency too high | TensorRT + FP16 | 2-4x faster |
| Need to run on mobile | Distillation + INT8 + ONNX | 10-100x smaller |
| Training too expensive | LoRA fine-tuning | ~100x fewer trainable parameters |
| Memory limited for LLM | GPTQ/AWQ 4-bit | 4-8x less memory |
| Edge deployment (no GPU) | Pruning + quantization + ONNX Runtime | CPU-friendly |


"What I cannot create, I do not understand." — Richard Feynman