Benchmarks and Optimization: From 48GB GPU to 8GB RTX
You have a model. It works on an 80 GB A100. But you need to deploy it on a 24 GB RTX 3090, or on an 8 GB laptop RTX 4060, or even on a Raspberry Pi. How do you know how much accuracy you lose moving from FP32 to INT4? How much speed do you gain with Flash Attention? Is it worth quantizing or is distillation better? How much memory does gradient checkpointing save?
Without systematic benchmarks, these questions remain unanswered — and you end up making suboptimal choices based on intuitions or benchmarks published with configurations different from your own. In this final article of the series, we build a complete benchmarking framework to measure every dimension of performance: memory, latency, throughput, accuracy, and power consumption.
We then systematically apply all the techniques seen in the series — quantization, pruning, distillation, Flash Attention, gradient checkpointing, mixed precision — and show how to go from a model requiring 48 GB to one running in 8 GB, with metrics that demonstrate exactly what you pay in terms of quality.
What You'll Learn
- Systematic benchmarking framework for DL models
- Precisely measuring VRAM, latency, throughput, and FLOPs
- Mixed Precision Training: FP16 vs BF16 vs FP32
- Flash Attention 2/3: how much it saves and when to use it
- Gradient Checkpointing: memory vs compute trade-off
- Gradient Accumulation: virtually large batch sizes
- torch.compile and runtime optimizations
- KV Cache: optimization for LLM autoregressive inference
- Systematic comparison: all techniques compared
- Decision guide: which optimization for which scenario
Systematic Benchmarking Framework
Before optimizing, you must measure precisely. A professional benchmarking framework measures: peak VRAM usage, mean and P95 latency, throughput (token/s or img/s), FLOPs, power consumption, and accuracy on specific tasks. The key is reproducibility: benchmarks that vary by 10% between runs are worthless.
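As a baseline for that reproducibility, it helps to pin every source of randomness before timing anything. A minimal sketch — the helper name `make_benchmark_reproducible` is ours, not part of the framework below:

```python
import random
import numpy as np
import torch

def make_benchmark_reproducible(seed: int = 42) -> None:
    """Pin every RNG and cuDNN's algorithm choice so repeated runs match."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Fixed algorithms: slightly slower, but run-to-run variance drops
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

make_benchmark_reproducible(42)
a = torch.randn(4, 4)
make_benchmark_reproducible(42)
b = torch.randn(4, 4)
print(torch.equal(a, b))  # prints True
```

Disabling `cudnn.benchmark` costs a little raw speed, but it removes one major source of run-to-run jitter in timing measurements.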
```python
import torch
import torch.nn as nn
import time
import numpy as np
from dataclasses import dataclass
from typing import Optional, Callable
import gc

# ============================================================
# DATACLASS FOR BENCHMARK RESULTS
# ============================================================

@dataclass
class BenchmarkResult:
    """Complete results from a benchmark."""
    name: str
    # Memory
    vram_allocated_mb: float
    vram_reserved_mb: float
    vram_peak_mb: float
    # Speed
    latency_ms_mean: float
    latency_ms_p50: float
    latency_ms_p95: float
    latency_ms_p99: float
    throughput_per_sec: float
    # Model
    params_total: int
    params_trainable: int
    model_size_mb: float
    # Optional
    accuracy: Optional[float] = None
    flops_total: Optional[float] = None
    power_watts: Optional[float] = None

    def print_summary(self):
        print(f"\n=== {self.name} ===")
        print(f"  VRAM: {self.vram_peak_mb:.0f} MB peak, {self.vram_allocated_mb:.0f} MB alloc")
        print(f"  Latency: {self.latency_ms_mean:.1f}ms mean, "
              f"{self.latency_ms_p95:.1f}ms p95, {self.latency_ms_p99:.1f}ms p99")
        print(f"  Throughput: {self.throughput_per_sec:.1f}/s")
        print(f"  Parameters: {self.params_total:,} ({self.model_size_mb:.1f} MB)")
        if self.accuracy is not None:  # 0.0 is a valid accuracy
            print(f"  Accuracy: {self.accuracy:.4f}")

# ============================================================
# MAIN BENCHMARKING CLASS
# ============================================================

class DeepLearningBenchmark:
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.results = []

    def _count_params(self, model: nn.Module) -> tuple:
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return total, trainable

    def _model_size_mb(self, model: nn.Module) -> float:
        total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        return total_bytes / (1024 ** 2)

    def _reset_memory(self):
        """Reset GPU memory stats for a clean benchmark."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    @staticmethod
    def _sync():
        """Wait for all queued CUDA kernels before reading the clock."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def benchmark_inference(
        self,
        name: str,
        model: nn.Module,
        input_fn: Callable,
        n_warmup: int = 10,
        n_runs: int = 100,
        batch_size: int = 1
    ) -> BenchmarkResult:
        """
        Complete inference benchmark.
        input_fn: function that returns inputs (a tensor or a dict) for the model.
        """
        model = model.to(self.device).eval()
        self._reset_memory()

        # Warmup
        with torch.no_grad():
            for _ in range(n_warmup):
                inputs = input_fn()
                if isinstance(inputs, dict):
                    model(**{k: v.to(self.device) for k, v in inputs.items()})
                else:
                    model(inputs.to(self.device))

        # Memory after warmup
        mem_alloc = mem_reserved = mem_peak = 0.0
        if torch.cuda.is_available():
            mem_alloc = torch.cuda.memory_allocated() / (1024 ** 2)
            mem_reserved = torch.cuda.memory_reserved() / (1024 ** 2)

        # Actual benchmark
        self._sync()
        latencies = []
        for _ in range(n_runs):
            inputs = input_fn()
            t0 = time.perf_counter()
            with torch.no_grad():
                if isinstance(inputs, dict):
                    _ = model(**{k: v.to(self.device) for k, v in inputs.items()})
                else:
                    _ = model(inputs.to(self.device))
            self._sync()
            latencies.append((time.perf_counter() - t0) * 1000)

        if torch.cuda.is_available():
            mem_peak = torch.cuda.max_memory_allocated() / (1024 ** 2)

        latencies = np.array(latencies)
        total_params, trainable_params = self._count_params(model)
        result = BenchmarkResult(
            name=name,
            vram_allocated_mb=mem_alloc,
            vram_reserved_mb=mem_reserved,
            vram_peak_mb=mem_peak,
            latency_ms_mean=float(np.mean(latencies)),
            latency_ms_p50=float(np.percentile(latencies, 50)),
            latency_ms_p95=float(np.percentile(latencies, 95)),
            latency_ms_p99=float(np.percentile(latencies, 99)),
            throughput_per_sec=1000 / float(np.mean(latencies)) * batch_size,
            params_total=total_params,
            params_trainable=trainable_params,
            model_size_mb=self._model_size_mb(model)
        )
        result.print_summary()
        self.results.append(result)
        return result

    def benchmark_training_step(
        self,
        name: str,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_fn: Callable,
        input_fn: Callable,
        n_steps: int = 50
    ) -> dict:
        """Benchmark a full training step (forward + backward + update)."""
        model = model.to(self.device).train()
        self._reset_memory()
        latencies = []
        for step in range(n_steps):
            inputs, labels = input_fn()
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            t0 = time.perf_counter()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            self._sync()
            latencies.append((time.perf_counter() - t0) * 1000)
        vram_peak = torch.cuda.max_memory_allocated() / (1024 ** 2) if torch.cuda.is_available() else 0.0
        return {
            "name": name,
            "vram_peak_mb": vram_peak,
            "step_ms_mean": float(np.mean(latencies[5:])),  # Skip warmup steps
            "step_ms_p95": float(np.percentile(latencies[5:], 95))
        }

    def compare_results(self) -> None:
        """Print a comparative table of all results."""
        if not self.results:
            print("No results available.")
            return
        baseline = self.results[0]
        print(f"\n{'Config':<30} {'VRAM (MB)':>12} {'Latency (ms)':>14} {'Throughput':>12} {'Speedup':>10}")
        print("-" * 82)
        for r in self.results:
            speedup = baseline.latency_ms_mean / r.latency_ms_mean
            print(f"{r.name:<30} {r.vram_peak_mb:>12.0f} {r.latency_ms_mean:>14.2f} "
                  f"{r.throughput_per_sec:>12.1f} {speedup:>10.2f}x")

# Usage:
bench = DeepLearningBenchmark(device="cuda" if torch.cuda.is_available() else "cpu")
print("Benchmarking framework initialized")
```
Mixed Precision: FP32 vs FP16 vs BF16
Mixed precision training is the first optimization to enable: near-zero configuration overhead, 2x memory savings, and often 2-3x speedup on Ampere+ hardware. torch.autocast automatically decides which operations run in reduced precision.
The key difference between FP16 and BF16 is the binary format: FP16 has 5 exponent bits and 10 mantissa bits (range 6e-5 to 6.5e4), while BF16 has 8 exponent bits and 7 mantissa bits (same range as FP32, from 1.2e-38 to 3.4e38). BF16 is much more stable during training because it does not cause overflow/underflow with large gradients.
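You can verify these ranges directly rather than memorizing them; `torch.finfo` exposes the limits of each floating-point format:

```python
import torch

# Inspect the limits of each format: BF16 keeps (almost) FP32's range,
# while FP16 trades range for extra mantissa precision.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>14}: max={info.max:.3e}  min_normal={info.tiny:.3e}  eps={info.eps:.1e}")
```

Anything whose absolute value exceeds 65504 overflows to inf in FP16 — the reason GradScaler exists — while BF16 only loses precision, not range.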
```python
import torch
import torch.nn as nn
import numpy as np
import time
import gc
from torch.cuda.amp import GradScaler  # torch.amp.GradScaler("cuda") in PyTorch 2.4+
from torchvision import models

# ============================================================
# FP32 vs FP16 vs BF16 COMPARISON
# ============================================================

def train_step_fp32(model, optimizer, imgs, labels, criterion):
    """Standard FP32 training step."""
    optimizer.zero_grad()
    output = model(imgs)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_step_fp16(model, optimizer, imgs, labels, criterion, scaler: GradScaler):
    """
    AMP FP16 training step.
    GradScaler is required: FP16 has a limited range, and loss scaling
    prevents gradient underflow.
    """
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(imgs)
        loss = criterion(output, labels)
    # Scale the loss to prevent underflow in FP16
    scaler.scale(loss).backward()
    # Unscale gradients before clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Update weights (step is skipped if gradients contain NaN/Inf)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

def train_step_bf16(model, optimizer, imgs, labels, criterion):
    """
    BF16 training step.
    BF16 does NOT require GradScaler: it has the same dynamic range as FP32.
    Available on: A100, RTX 3000+, Apple M-series.
    """
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(imgs)
        loss = criterion(output, labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss.item()

# Comparative benchmark
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compare_precisions(model_fn=models.resnet50, n_steps=100,
                       batch_size=32, img_size=224):
    """Compare FP32, FP16, BF16 training steps."""
    criterion = nn.CrossEntropyLoss()
    configs = [
        ("FP32", torch.float32, False),
        ("FP16", torch.float16, True),    # Requires GradScaler
        ("BF16", torch.bfloat16, False)   # No GradScaler needed
    ]
    results = {}
    for name, dtype, use_scaler in configs:
        model = model_fn(weights=None).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        scaler = GradScaler() if use_scaler else None
        # Reset memory stats
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        timings = []
        for step in range(n_steps):
            imgs = torch.randn(batch_size, 3, img_size, img_size, device=device)
            labels = torch.randint(0, 1000, (batch_size,), device=device)
            t0 = time.perf_counter()
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=dtype,
                                enabled=(dtype != torch.float32)):
                out = model(imgs)
                loss = criterion(out, labels)
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            timings.append((time.perf_counter() - t0) * 1000)
        vram_peak = torch.cuda.max_memory_allocated() / (1024**2) if torch.cuda.is_available() else 0
        results[name] = {
            "vram_mb": round(vram_peak, 1),
            "step_ms": round(float(np.mean(timings[10:])), 2),
            "throughput_imgs_s": round(batch_size * 1000 / float(np.mean(timings[10:])), 1)
        }
        print(f"{name}: VRAM={vram_peak:.0f}MB, {np.mean(timings[10:]):.1f}ms/step, "
              f"{batch_size * 1000 / np.mean(timings[10:]):.0f} img/s")
    return results

# Typical results, ResNet-50 BS=32 on RTX 4090:
# FP32: VRAM=6200MB, 95ms/step, 336 img/s
# FP16: VRAM=3100MB, 41ms/step, 780 img/s (2x speed, 50% VRAM)
# BF16: VRAM=3100MB, 38ms/step, 842 img/s (2.2x speed, 50% VRAM)
```
Flash Attention: The Optimization That Changes the Rules
Flash Attention (Dao et al., 2022) is perhaps the most impactful optimization for Transformers in recent years. It reformulates the attention computation to be IO-aware: instead of materializing the full attention matrix in HBM (with O(n^2) memory complexity), it computes attention in blocks that stay in SRAM. The result: O(n) memory complexity instead of O(n^2), and a 2-4x speedup on long sequences.
Flash Attention 2 (2023) further improves parallelism on GPU, achieving 72% of the theoretical FP16 FLOPS utilization. Flash Attention 3 (2024) adds FP8 support and Hopper-specific optimizations, with up to 2x speedup over FA2.
```python
import torch
import torch.nn.functional as F
import time

# ============================================================
# FLASH ATTENTION vs STANDARD ATTENTION: COMPARISON
# ============================================================

def standard_attention(q, k, v, scale=None):
    """
    Standard attention: materializes the full NxN score matrix in GPU memory.
    Memory complexity: O(N^2) per head.
    """
    if scale is None:
        scale = q.size(-1) ** -0.5
    # [B, heads, N, N] - this matrix can be HUGE for long sequences!
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v

def flash_attention_native(q, k, v):
    """
    Flash Attention via PyTorch 2.0+ scaled_dot_product_attention.
    Automatically chooses the optimal implementation:
    - FlashAttention-2 if available (CUDA, Ampere+)
    - Memory-efficient attention (xFormers-style) as fallback
    - Standard attention as last resort
    """
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

def benchmark_attention_implementations(
    batch_size=4, n_heads=12, seq_lengths=(512, 1024, 2048, 4096, 8192),
    d_head=64, device="cuda"
):
    """Compare standard vs Flash Attention across sequence lengths."""
    print(f"{'Seq Len':>10} | {'Standard (ms)':>15} | {'Flash (ms)':>12} | "
          f"{'Speedup':>10} | {'VRAM Std (MB)':>15} | {'VRAM Flash (MB)':>15}")
    print("-" * 90)
    for seq_len in seq_lengths:
        q = torch.randn(batch_size, n_heads, seq_len, d_head,
                        device=device, dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # Warmup
        for _ in range(5):
            standard_attention(q, k, v)
            flash_attention_native(q, k, v)
        # Benchmark standard attention
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(20):
            out_std = standard_attention(q, k, v)
        torch.cuda.synchronize()
        std_ms = (time.perf_counter() - t0) / 20 * 1000
        vram_std = torch.cuda.max_memory_allocated() / (1024**2)
        # Benchmark Flash Attention
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(20):
            out_flash = flash_attention_native(q, k, v)
        torch.cuda.synchronize()
        flash_ms = (time.perf_counter() - t0) / 20 * 1000
        vram_flash = torch.cuda.max_memory_allocated() / (1024**2)
        speedup = std_ms / flash_ms
        print(f"{seq_len:>10} | {std_ms:>15.2f} | {flash_ms:>12.2f} | "
              f"{speedup:>10.2f}x | {vram_std:>15.0f} | {vram_flash:>15.0f}")

# Typical results on RTX 4090 (FP16, B=4, heads=12, d_head=64):
# Seq Len | Standard (ms) | Flash (ms) |  Speedup | VRAM Std (MB) | VRAM Flash (MB)
# -----------------------------------------------------------------------------------
#     512 |          0.82 |       0.31 |    2.65x |            48 |              12
#    1024 |          2.45 |       0.58 |    4.22x |           192 |              24
#    2048 |          9.12 |       1.12 |    8.14x |           768 |              48
#    4096 |         35.80 |       2.21 |   16.20x |          3072 |              96
#    8192 |        144.20 |       4.38 |   32.92x |         12288 |             192
# Flash Attention memory scales LINEARLY: at seq=8192 it uses 64x less VRAM!
```
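Beyond speed, it is worth a sanity check that the fused kernel is numerically equivalent to the naive implementation. A small CPU-side check in FP32 (the tensor shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [batch, heads, seq, d_head] — small enough to materialize the full score matrix
q = torch.randn(2, 4, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Naive attention: explicit softmax over the full N x N score matrix
naive = torch.softmax((q @ k.transpose(-2, -1)) * q.size(-1) ** -0.5, dim=-1) @ v
# Fused kernel (PyTorch >= 2.0); on CPU this dispatches to the math backend
fused = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(torch.allclose(naive, fused, atol=1e-4))
```

The two results agree up to floating-point tolerance; any larger discrepancy in your own benchmark usually means a masking or scaling mismatch, not a kernel bug.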
Gradient Checkpointing and Gradient Accumulation
When VRAM is the bottleneck during training, two complementary techniques allow training with larger batches without upgrading hardware:
```python
import torch
import torch.nn as nn
import time
import gc
from torch.utils.checkpoint import checkpoint

# ============================================================
# GRADIENT CHECKPOINTING
# ============================================================
# Idea: instead of saving all intermediate activations for the backward pass,
# recompute them on the fly (trade-off: ~+33% compute, -50-70% memory)

class CheckpointedTransformerBlock(nn.Module):
    """Transformer block with optional gradient checkpointing."""
    def __init__(self, d_model=768, n_heads=12, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def _attn_block(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.norm1(x + attn_out)

    def _ff_block(self, x):
        return self.norm2(x + self.ff(x))

    def forward(self, x):
        if self.use_checkpoint:
            # Each sub-module is recomputed during backward instead of saved
            x = checkpoint(self._attn_block, x, use_reentrant=False)
            x = checkpoint(self._ff_block, x, use_reentrant=False)
        else:
            x = self._ff_block(self._attn_block(x))
        return x

def enable_gradient_checkpointing_hf(model):
    """Enable gradient checkpointing on HuggingFace models."""
    model.gradient_checkpointing_enable()
    print(f"Gradient checkpointing enabled on {type(model).__name__}")

# Gradient checkpointing benchmark
def compare_checkpointing(seq_len=2048, batch_size=8, d_model=768,
                          n_layers=12, n_heads=12, device="cuda"):
    """Compare training with and without gradient checkpointing."""

    class SimpleTransformer(nn.Module):
        def __init__(self, use_checkpoint=False):
            super().__init__()
            # The flag is passed to the blocks, so the baseline truly
            # runs WITHOUT checkpointing
            self.blocks = nn.ModuleList([
                CheckpointedTransformerBlock(d_model, n_heads, use_checkpoint)
                for _ in range(n_layers)
            ])
            self.head = nn.Linear(d_model, 1000)

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return self.head(x[:, 0])

    results = {}
    for use_ckpt in [False, True]:
        name = "with checkpointing" if use_ckpt else "without checkpointing"
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        model = SimpleTransformer(use_checkpoint=use_ckpt).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        x = torch.randn(batch_size, seq_len, d_model, device=device)
        labels = torch.randint(0, 1000, (batch_size,), device=device)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(10):
            optimizer.zero_grad()
            out = model(x)
            loss = nn.CrossEntropyLoss()(out, labels)
            loss.backward()
            optimizer.step()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = (time.perf_counter() - t0) / 10 * 1000
        vram = torch.cuda.max_memory_allocated() / (1024**2) if torch.cuda.is_available() else 0
        results[name] = {"vram_mb": round(vram, 1), "step_ms": round(elapsed, 1)}
        print(f"{name}: VRAM={vram:.0f}MB, Step={elapsed:.1f}ms")
    return results

# Typical results (12-layer Transformer, seq=2048, BS=8, RTX 3090):
# Without checkpointing: VRAM=18.4GB, Step=285ms
# With checkpointing:    VRAM= 7.8GB, Step=378ms (-58% VRAM, +33% compute)

# ============================================================
# GRADIENT ACCUMULATION
# ============================================================

def train_with_gradient_accumulation(
    model, optimizer, train_loader, criterion,
    accumulation_steps: int = 4,
    device: str = "cuda"
):
    """
    Gradient accumulation: simulates batch_size * accumulation_steps
    with the memory footprint of batch_size.
    Useful when the actual batch size is too small for stable convergence.
    """
    model = model.to(device).train()
    optimizer.zero_grad()
    for step, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        # Forward pass
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = model(imgs)
            # Divide the loss by the accumulation steps (keeps gradient scale correct)
            loss = criterion(output, labels) / accumulation_steps
        loss.backward()
        # Update weights every N micro-batches
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            effective_batch = imgs.size(0) * accumulation_steps
            print(f"Step {(step + 1) // accumulation_steps} | "
                  f"Effective batch: {effective_batch} | "
                  f"Loss: {loss.item() * accumulation_steps:.4f}")
```
torch.compile: Graph Optimization
torch.compile (PyTorch 2.0+) compiles the model into optimized kernels via Triton or other backends. It is the simplest optimization to apply: a single line of code can yield 1.5-2.5x speedup on inference.
```python
import torch
import numpy as np
import time
from torchvision import models

def benchmark_torch_compile():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ============================================================
    # COMPILATION MODES
    # ============================================================
    # mode="default":         balance between compile time and speedup
    # mode="reduce-overhead": minimizes launch overhead, best for small batches
    # mode="max-autotune":    maximum speed (much longer compile time, ~5-10 min)
    # backend="inductor":     default backend (Triton on CUDA, C++ on CPU)

    model_eager = models.resnet50(weights=None).to(device).eval()

    # Compiled with the default mode
    model_compiled_default = torch.compile(
        models.resnet50(weights=None).to(device).eval(),
        mode="default"
    )

    # Compiled for maximum speed
    model_compiled_max = torch.compile(
        models.resnet50(weights=None).to(device).eval(),
        mode="max-autotune",
        fullgraph=True  # Avoid graph breaks for maximum speedup
    )

    x = torch.randn(32, 3, 224, 224, device=device)

    def time_model(model, x, n=100):
        """Benchmark with warmup."""
        # Warmup (crucial for torch.compile: the first calls trigger compilation)
        with torch.no_grad():
            for _ in range(20):
                model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        latencies = []
        with torch.no_grad():
            for _ in range(n):
                t0 = time.perf_counter()
                model(x)
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                latencies.append((time.perf_counter() - t0) * 1000)
        return np.mean(latencies)

    ms_eager = time_model(model_eager, x)
    ms_default = time_model(model_compiled_default, x)
    ms_max = time_model(model_compiled_max, x)  # warmup here pays the autotune cost
    print(f"Eager (FP32):          {ms_eager:.2f} ms")
    print(f"Compiled default:      {ms_default:.2f} ms ({ms_eager/ms_default:.2f}x speedup)")
    print(f"Compiled max-autotune: {ms_max:.2f} ms ({ms_eager/ms_max:.2f}x speedup)")

    # BF16 + compile: the effects multiply (convert to BF16 BEFORE compiling)
    model_bf16 = models.resnet50(weights=None).to(device).to(torch.bfloat16).eval()
    model_bf16_compiled = torch.compile(model_bf16, mode="default")
    x_bf16 = x.to(torch.bfloat16)
    ms_bf16_compiled = time_model(model_bf16_compiled, x_bf16)
    print(f"BF16 + Compiled:       {ms_bf16_compiled:.2f} ms ({ms_eager/ms_bf16_compiled:.2f}x speedup)")

# Typical results on RTX 4090:
# Eager FP32:       12.4 ms/step (BS=32)
# Compiled default:  7.8 ms/step (1.59x)
# BF16 + Compiled:   5.1 ms/step (2.43x)

benchmark_torch_compile()
```
KV Cache: Optimization for LLM Autoregressive Inference
In autoregressive models, each generated token must attend to all previous tokens. Without a cache, the keys (K) and values (V) of the entire prefix are recomputed at every step, so generating the n-th token costs O(n^2) attention work. The KV Cache stores each layer's K and V after every step: a new token then only computes its own Q, K, V and attends to the cached keys, cutting the per-token cost from O(n^2) to O(n).
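The cache is not free: it trades compute for memory. Per layer you store one K and one V tensor of shape [batch, kv_heads, seq_len, d_head], which a quick estimator makes concrete (the config numbers below are illustrative, loosely modeled on a Llama-3.1-8B-style architecture with grouped-query attention):

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Two tensors (K and V) per layer, each [batch, n_kv_heads, seq_len, d_head]."""
    return 2 * n_layers * batch * n_kv_heads * seq_len * d_head * bytes_per_elem

# 32 layers, 8 KV heads (GQA), d_head=128, FP16 cache, 8k context
gb = kv_cache_bytes(32, 8, 128, seq_len=8192) / 1024**3
print(f"KV cache @ 8k context: {gb:.2f} GiB")  # prints 1.00 GiB
```

The cache grows linearly with context length and batch size, which is why long-context, high-batch serving is often cache-bound rather than weight-bound.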
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Optional, Tuple

# ============================================================
# TRANSFORMER WITH KV CACHE
# ============================================================

class CachedMultiHeadAttention(nn.Module):
    """
    Multi-head attention with a KV cache for autoregressive generation.
    The cache avoids recomputing K, V for past tokens.
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = self.d_head ** -0.5
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # Concatenate with the existing cache along the sequence dimension
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        # Attention (PyTorch 2.0+ picks Flash Attention automatically when it can).
        # Causal masking is only needed during prefill; in decoding, the single
        # query token may attend to everything already in the cache.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=(kv_cache is None))
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out), (k, v)  # output + updated cache

class CachedTransformerDecoder(nn.Module):
    """Simplified Transformer decoder (attention-only blocks) with KV cache."""
    def __init__(self, vocab_size: int, d_model: int = 512,
                 n_heads: int = 8, n_layers: int = 6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(2048, d_model)
        self.layers = nn.ModuleList([
            CachedMultiHeadAttention(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.head = nn.Linear(d_model, vocab_size)
        self.n_layers = n_layers

    def _forward_layers(self, x, kv_caches):
        # Updates kv_caches in place and returns the residual stream
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            attn_out, kv_caches[i] = layer(norm(x), kv_caches[i])
            x = x + attn_out
        return x

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV cache: the prompt is processed once
        (prefill), then each new token only computes its own Q/K/V and reads
        past K/V from the cache.
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Prefill: process the whole prompt and populate the caches
        positions = torch.arange(T, device=device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(positions)
        kv_caches = [None] * self.n_layers
        x = self._forward_layers(x, kv_caches)

        # The first new token comes from the prefill logits (re-feeding the
        # last prompt token here would duplicate its K/V in the cache)
        logits = self.head(x[:, -1, :]) / temperature
        new_token = torch.multinomial(torch.softmax(logits, -1), 1)
        generated = [new_token]

        # Token-by-token decoding (one position at a time, using the cache)
        for step in range(1, max_new_tokens):
            pos = torch.full((B, 1), T + step - 1, device=device)
            x_new = self.embed(new_token) + self.pos_embed(pos)
            x_new = self._forward_layers(x_new, kv_caches)
            logits = self.head(x_new[:, -1, :]) / temperature
            new_token = torch.multinomial(torch.softmax(logits, -1), 1)
            generated.append(new_token)
        return torch.cat(generated, dim=1)

# KV cache benchmark
def benchmark_generation(model, vocab_size=32000, seq_len=128,
                         max_new=50, device="cuda"):
    model = model.to(device).eval()
    input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)
    t0 = time.perf_counter()
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=max_new)
    t_cached = (time.perf_counter() - t0) * 1000
    tokens_per_sec = max_new / (t_cached / 1000)
    print(f"With KV Cache: {t_cached:.1f}ms total, {tokens_per_sec:.1f} token/s")
```
Systematic Comparison: From 48GB to 8GB RTX
We summarize all the optimizations seen in the series, applying them progressively to a base model and showing the accuracy/memory/speed trade-off.
Full Comparison: Llama-3.1-8B on RTX 3090 (24GB)
| Configuration | VRAM | Throughput | HellaSwag | Perplexity | Notes |
|---|---|---|---|---|---|
| BF16 baseline | 16.0 GB | 38 t/s | 82.1% | 6.14 | Reference benchmark |
| + Flash Attention 2 | 14.2 GB | 52 t/s | 82.1% | 6.14 | -11% VRAM, +37% speed |
| + torch.compile | 14.2 GB | 68 t/s | 82.1% | 6.14 | +31% on Flash Attention |
| INT8 (bitsandbytes) | 8.5 GB | 35 t/s | 81.8% | 6.21 | -47% VRAM, -0.3% acc |
| INT4 NF4 (bnb) | 4.9 GB | 42 t/s | 81.2% | 6.47 | -69% VRAM, -0.9% acc |
| GPTQ INT4 | 4.8 GB | 55 t/s | 81.5% | 6.39 | -70% VRAM, -0.6% acc |
| AWQ INT4 | 4.7 GB | 52 t/s | 81.6% | 6.35 | -71% VRAM, -0.5% acc |
| GGUF Q4_K_M (CPU) | 0 VRAM (5 GB RAM) | 18 t/s | 81.3% | 6.42 | No GPU required |
Indicative values on RTX 3090 (24GB VRAM). Throughput measured with batch=1, seq=512.
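The VRAM column is dominated by weight storage, which you can approximate as params × bits / 8; real usage runs higher because of activations, the CUDA context, and quantization metadata (scales and zero-points). The 4.5 bits/weight figure below is a rough allowance for that metadata, not an exact number:

```python
def weight_memory_gb(n_params: float, bits_per_weight: float) -> float:
    """Approximate weight storage only: params * bits / 8, in decimal GB."""
    return n_params * bits_per_weight / 8 / 1e9

# Llama-3.1-8B has ~8.03B parameters
for name, bits in [("BF16", 16), ("INT8", 8), ("INT4 + scales", 4.5)]:
    print(f"{name:>14}: ~{weight_memory_gb(8.03e9, bits):.1f} GB")
```

These estimates land close to the measured 16.0 / 8.5 / 4.7-4.9 GB rows in the table once runtime overheads are added on top.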
Decision Guide: Which Optimization for Which Scenario
```python
# DECISION TREE FOR DL OPTIMIZATION

def recommend_optimization(
    vram_available_gb: float,
    task: str,               # "training" | "inference" | "edge"
    accuracy_critical: bool,
    hardware: str            # "server_gpu" | "consumer_gpu" | "cpu" | "edge"
) -> dict:
    """Recommends the most appropriate optimizations for your scenario."""
    priority = []

    # === ALWAYS DO (zero or near-zero cost) ===
    priority.append("1. Mixed Precision (BF16/FP16): ALWAYS enable on Ampere+ GPUs")
    priority.append("2. Flash Attention: enable if seq_len > 512")
    priority.append("3. torch.compile: enable on PyTorch 2.0+, +30-50% inference speedup")
    priority.append("4. KV Cache: ALWAYS enable for LLM autoregressive generation")

    if task == "training":
        if vram_available_gb < 24:
            priority.append("5. Gradient Checkpointing: -50% VRAM, +33% compute")
            priority.append("6. Gradient Accumulation: simulate larger batches")
        if hardware in ["consumer_gpu", "edge"]:
            priority.append("7. QLoRA: fine-tuning with INT4 + LoRA on consumer GPUs")

    if task in ["inference", "edge"]:
        if not accuracy_critical:
            if hardware == "server_gpu":
                priority.append("5. GPTQ INT4: maximum throughput on NVIDIA GPUs")
            elif hardware in ["consumer_gpu", "cpu"]:
                priority.append("5. AWQ INT4 or GGUF Q4_K_M: for heterogeneous hardware")
            elif hardware == "edge":
                priority.append("5. GGUF Q3_K_M or Q4_K_M: for Raspberry Pi / embedded")
        else:
            priority.append("5. INT8 (bitsandbytes): minimal accuracy loss")
        if vram_available_gb < 16:
            priority.append("6. ONNX export: cuts runtime overhead by 20-40%")
            priority.append("7. Consider distillation toward a smaller model")

    print("=== OPTIMIZATION RECOMMENDATIONS ===")
    for p in priority:
        print(f"  {p}")
    return {"priorities": priority}

# Examples:
print("--- Scenario 1: Fine-tuning on RTX 4080 (16GB) ---")
recommend_optimization(16, "training", True, "consumer_gpu")
print("\n--- Scenario 2: Inference on Raspberry Pi ---")
recommend_optimization(0, "inference", False, "edge")
print("\n--- Scenario 3: Production on A100 (80GB) ---")
recommend_optimization(80, "inference", True, "server_gpu")
```
Optimization Summary: Expected Impact
| Technique | VRAM Saving | Speedup | Acc Loss | Complexity |
|---|---|---|---|---|
| Mixed Precision BF16 | -50% | 2-3x | 0% | Low (1 line) |
| Flash Attention 2 | -50-90% | 2-8x | 0% | Low (1 line) |
| torch.compile | 0% | 1.5-2.5x | 0% | Low (1 line) |
| KV Cache | costs VRAM (cache) | 10-50x generation | 0% | Low |
| Gradient Checkpointing | -50-70% | 0.7-0.8x (slower) | 0% | Low |
| INT8 Quantization | -50% | 0.9-1.1x | 0-0.5% | Low |
| INT4 GPTQ/AWQ | -75% | 1.3-1.8x | 0.5-1.5% | Medium |
| Distillation | -70-90% | 5-20x | 5-15% | High |
| Structured Pruning | -30-70% | 2-5x | 2-10% | High |
Series Conclusions
We have covered the entire Advanced Deep Learning and Edge Deployment series: from attention mechanisms in Transformers to fine-tuning with LoRA, from GPTQ quantization to structured pruning, from distillation to Vision Transformers, from NAS to edge deployment with Raspberry Pi and Jetson, from Ollama to this final benchmark.
The central message is clear: there is no single "best" technique. The optimal choice always depends on context — available hardware, accuracy requirements, latency target, operational costs. But with the systematic benchmarking framework presented in this article, you can measure instead of guess, and make informed decisions.
The trend is clear: models are moving toward the edge. Gartner predicts that by 2027, small, task-specific models will be used three times more than general-purpose LLMs. The techniques in this series — quantization, distillation, edge deployment, Ollama — are not academic niches: they are foundational skills for anyone who wants to work with AI in the coming years.
Series Summary: Advanced Deep Learning
- Article 1: Attention Mechanism in Transformers
- Article 2: Fine-tuning with LoRA and QLoRA
- Article 3: GPTQ, AWQ, INT8 Quantization
- Article 4: Knowledge Distillation
- Article 5: Neural Network Pruning
- Article 6: Vision Transformer (ViT)
- Article 7: Neural Architecture Search
- Article 8: Deep Learning on Edge Devices
- Article 9: Ollama and Local LLMs
- Article 10 (this): Benchmarks and Optimization
Related series: MLOps | Computer Vision | AI Engineering