Knowledge Distillation: Comprimere Modelli Complessi
GPT-4 e troppo grande per girare sul tuo laptop. Un ResNet-152 e troppo lento per il tuo dispositivo mobile. Eppure la conoscenza acquisita da questi modelli giganteschi attraverso settimane di training su cluster GPU può essere trasferita a modelli molto più piccoli e veloci, con una perdita di accuratezza sorprendentemente modesta. Questa e la promessa della Knowledge Distillation.
Proposta da Hinton, Vinyals e Dean nel 2015, la distillazione e diventata una delle tecniche più potenti e versatili nel toolkit del deep learning moderno. Il principio e elegante: un modello grande (teacher) guida il training di un modello più piccolo (student) non semplicemente con le etichette hard (0/1), ma con le soft probabilities del teacher — distribuzioni di probabilità ricche di informazioni sulle similarità tra classi. Questa informazione implicita — la "dark knowledge" — e ciò che rende la distillazione cosi efficace.
In questa guida esploriamo la distillazione in profondità: dalla teoria originale alle varianti moderne (feature distillation, attention transfer, self-distillation, LLM distillation), con implementazioni complete in PyTorch, best practices per la produzione e un caso di studio reale.
Cosa Imparerai
- Teoria della distillazione: soft labels, temperatura e dark knowledge
- Implementazione della distillazione standard con PyTorch
- Feature Distillation: trasferire rappresentazioni intermedie
- Attention Transfer: distillare le attention maps dei Transformer
- Self-Distillation e Born Again Networks
- Distillazione offline vs online vs self
- Distillazione per LLM: da modelli grandi a modelli edge
- Combinare distillazione con quantizzazione e pruning
- Best practices, errori comuni e metriche di valutazione
- Caso di studio: DistilBERT step by step
Il Principio della Distillazione: Dark Knowledge
Un classificatore standard addestrato su hard labels (one-hot) usa informazioni molto scarse: la classe corretta ha probabilità 1, tutte le altre 0. Ma un modello ben addestrato sa molto di più: sa che un gatto e più simile a un cane che a un'automobile. Queste informazioni sono contenute nella distribuzione di probabilità del teacher, anche quando la risposta corretta ha probabilità quasi 1.
Il trucco della distillazione e usare una temperatura T per "ammorbidire" le probabilità del teacher, amplificando le differenze tra classi improbabili ma informative. Con T=1 si ha la distribuzione originale; con T alto (es. T=4) le probabilità diventano più uniformi, rivelando le relazioni di similarità implicite. Questo meccanismo e chiamato dark knowledge — la conoscenza nascosta nei logits del modello che una semplice label binaria non può catturare.
# Visualizzazione effetto della temperatura sulla dark knowledge
import torch
import torch.nn.functional as F
import numpy as np
# Supponiamo che il teacher produca questi logits per un campione
# con classe vera = 0 (gatto)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Classi ipotetiche: [gatto, cane, felino, leone, volpe, auto, aereo, nave, treno, barca]
classi = ["gatto", "cane", "felino", "leone", "volpe", "auto", "aereo", "nave", "treno", "barca"]
print("Effetto della temperatura sulle soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
probs = F.softmax(teacher_logits / T, dim=0)
entropia = -(probs * probs.log()).sum().item()
print(f"T={T:2d}: p(gatto)={probs[0]:.4f}, "
f"p(cane)={probs[1]:.4f}, p(felino)={probs[2]:.4f}, "
f"entropia={entropia:.3f}")
# Output:
# T= 1: p(gatto)=0.9833, p(cane)=0.0106, p(felino)=0.0079, entropia=0.123
# T= 2: p(gatto)=0.9175, p(cane)=0.0440, p(felino)=0.0297, entropia=0.424
# T= 4: p(gatto)=0.7562, p(cane)=0.1168, p(felino)=0.0913, entropia=0.895
# T= 8: p(gatto)=0.5756, p(cane)=0.1572, p(felino)=0.1343, entropia=1.387
# T=20: p(gatto)=0.3520, p(cane)=0.1668, p(felino)=0.1600, entropia=1.944
#
# Con T alta, il teacher rivela che "gatto" e molto simile a "cane" e "felino"
# ma niente a che fare con "auto" o "aereo".
# Questa e la DARK KNOWLEDGE che lo student impara!
print("\nAnalisi della dark knowledge:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, classe in enumerate(classi):
print(f" {classe:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
La Matematica della Distillazione
La loss di distillazione combina due termini con un hyperparameter di bilanciamento alpha:
- L_distill: KL divergence tra le soft probabilities dello student e del teacher (a temperatura T), moltiplicata per T² per compensare la riduzione di grandezza del gradiente
- L_student: Cross-entropy standard tra predizioni dello student e le etichette hard
La formula finale e: L = alpha * T² * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# ============================================================
# DISTILLATION LOSS - Implementazione completa
# ============================================================
class DistillationLoss(nn.Module):
"""
Loss per Knowledge Distillation (Hinton et al., 2015).
L = alpha * T^2 * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
Il fattore T^2 e fondamentale: quando si usa T > 1,
i gradienti si riducono di 1/T^2. Moltiplicando per T^2
si compensa questo effetto, mantenendo scale dei gradienti
coerenti tra la KD loss e la CE loss.
"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
"""
temperature: scala le probabilità soft (tipico: 2-8)
alpha: peso della distillation loss (tipico: 0.5-0.9)
"""
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor) -> dict:
# Soft probabilities a temperatura T
# NB: log_softmax per student (richiesto da KLDivLoss)
student_soft = F.log_softmax(student_logits / self.T, dim=1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
# KL divergence * T^2 per compensare la riduzione del gradiente
loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
# Cross-entropy standard con hard labels
loss_student = self.ce_loss(student_logits, labels)
# Combinazione pesata
total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
return {
'total': total_loss,
'distill': loss_distill.detach(),
'student': loss_student.detach()
}
# ============================================================
# MODELLI TEACHER e STUDENT
# ============================================================
def create_teacher_student(n_classes: int = 100):
"""
Teacher: ResNet-50 (~25M parametri) - pre-trained su ImageNet
Student: MobileNetV3-Small (~2.5M parametri) - 10x più piccolo
"""
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, n_classes
)
total_teacher = sum(p.numel() for p in teacher.parameters())
total_student = sum(p.numel() for p in student.parameters())
flops_teacher = 4.1e9 # Approx FLOPs ResNet-50
flops_student = 0.056e9 # Approx FLOPs MobileNetV3-S
print(f"Teacher (ResNet-50): {total_teacher:,} param, {flops_teacher/1e9:.1f}G FLOPs")
print(f"Student (MobileNetV3): {total_student:,} param, {flops_student*1000:.0f}M FLOPs")
print(f"Fattore compressione: {total_teacher/total_student:.1f}x param, "
f"{flops_teacher/flops_student:.0f}x FLOPs")
return teacher, student
# ============================================================
# TRAINING LOOP CON DISTILLAZIONE
# ============================================================
def train_with_distillation(
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.7,
lr: float = 1e-3,
device: str = "cuda"
):
teacher = teacher.to(device).eval() # Teacher: SOLO inference, no backprop!
student = student.to(device)
criterion = DistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
best_acc = 0.0
history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}
for epoch in range(n_epochs):
student.train()
total_loss = distill_loss_sum = student_loss_sum = 0.0
n_batches = 0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass teacher (CRITICO: no gradients, risparmia memoria!)
with torch.no_grad():
teacher_logits = teacher(imgs)
# Forward pass student (con gradients)
student_logits = student(imgs)
# Loss combinata
losses = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
losses['total'].backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
optimizer.step()
total_loss += losses['total'].item()
distill_loss_sum += losses['distill'].item()
student_loss_sum += losses['student'].item()
n_batches += 1
scheduler.step()
# Validation
student.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
preds = student(imgs).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_total = total_loss / n_batches
history['train_loss'].append(avg_total)
history['val_acc'].append(val_acc)
history['distill_loss'].append(distill_loss_sum / n_batches)
history['student_loss'].append(student_loss_sum / n_batches)
if val_acc > best_acc:
best_acc = val_acc
torch.save(student.state_dict(), "best_student.pth")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")
print(f"\nMiglior accuracy student: {best_acc:.4f}")
return history, best_acc
# Risultati tipici CIFAR-100:
# ResNet-50 teacher: 78.2% Top-1
# MobileNetV3-S senza KD: 67.1% Top-1
# MobileNetV3-S con KD: 71.4% Top-1 (+4.3%)
# MobileNetV3-S con KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x param, 73x FLOPs
Feature Distillation: Trasferire le Rappresentazioni Interne
La distillazione sulle soft labels trasferisce solo l'output finale del teacher. La Feature Distillation va oltre: forza lo student a replicare anche le rappresentazioni intermedie del teacher — i feature map a diversi livelli della rete. Questo e particolarmente efficace quando teacher e student hanno architetture molto diverse (es. CNN teacher, ViT student) e quando il task richiede feature spaziali ricche.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
"""Cattura le feature di layer specifici tramite forward hooks."""
def __init__(self, model: nn.Module, layer_names: list):
self.features = {}
self.hooks = []
for name, module in model.named_modules():
if name in layer_names:
hook = module.register_forward_hook(
lambda m, inp, out, n=name: self.features.update({n: out})
)
self.hooks.append(hook)
def get_features(self) -> list:
return list(self.features.values())
def clear(self):
self.features.clear()
def remove(self):
"""Rimuovi hooks per evitare memory leaks."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
"""
Loss che combina:
1. KD loss standard (soft labels output)
2. Feature Matching Loss (MSE tra feature intermedie normalizzate)
3. Relation-Based Loss (distanze relative tra sample nel batch)
"""
def __init__(self, student_channels: list, teacher_channels: list,
temperature: float = 4.0, alpha: float = 0.4,
beta: float = 0.4, gamma: float = 0.2):
"""
alpha: peso KD loss
beta: peso feature matching loss
gamma: peso CE loss standard
(alpha + beta + gamma deve essere 1.0)
"""
super().__init__()
assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Pesi devono sommare a 1"
self.T = temperature
self.alpha = alpha
self.beta = beta
self.gamma = gamma
# Adattatori per allineare dimensioni teacher->student
# Esempio: teacher ha 2048 canali, student 96 -> adapter 1x1 conv
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(t_ch, s_ch, 1, bias=False),
nn.BatchNorm2d(s_ch),
nn.ReLU(inplace=True)
)
for t_ch, s_ch in zip(teacher_channels, student_channels)
])
def forward(self, student_logits, teacher_logits, labels,
student_features: list, teacher_features: list):
# 1. KD Loss (soft labels)
kl = nn.KLDivLoss(reduction='batchmean')
loss_kd = kl(
F.log_softmax(student_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1)
) * self.T ** 2
# 2. CE Loss standard
loss_ce = F.cross_entropy(student_logits, labels)
# 3. Feature Matching Loss
loss_feat = torch.tensor(0.0, device=student_logits.device)
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Adatta canali del teacher allo student
t_adapted = self.adapters[i](t_feat.detach())
# Allinea risoluzione spaziale se necessario
if s_feat.shape[2:] != t_adapted.shape[2:]:
t_adapted = F.interpolate(
t_adapted, size=s_feat.shape[2:],
mode='bilinear', align_corners=False
)
# Normalizza le feature (cosine similarity invece di MSE)
s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
# MSE tra feature normalizzate
loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
loss_feat = loss_feat / max(len(student_features), 1)
total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'feat': loss_feat.detach()
}
# Configurazione per ResNet-50 teacher -> MobileNetV3-S student
# Layer teacher: [layer2, layer3, layer4] -> Canali: [512, 1024, 2048]
# Layer student: [features.4, features.9, features.12] -> Canali: [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)
teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]
teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)
feat_criterion = FeatureDistillationLoss(
student_channels=student_channels,
teacher_channels=teacher_channels,
temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup completato!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Attention Transfer per Transformer e Vision Transformer
I Vision Transformer producono attention maps esplicite che possono essere direttamente distillate. DeiT (Data-efficient Image Transformer) usa questo approccio con un token di distillazione speciale. L'Attention Transfer (Zagoruyko & Komodakis, 2017) estende il concetto anche alle CNN, costruendo mappe di attenzione dalle attivazioni dei layer convoluzionali.
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ============================================================
# ATTENTION TRANSFER (AT) per CNN
# ============================================================
class AttentionTransferLoss(nn.Module):
"""
Attention Transfer (Zagoruyko & Komodakis, 2017).
Forza lo student a replicare le attention maps del teacher.
Efficace per transfer tra architetture diverse (CNN <-> ViT).
"""
def __init__(self, beta: float = 1000.0):
super().__init__()
self.beta = beta
def attention_map(self, features: torch.Tensor) -> torch.Tensor:
"""
Calcola mappa di attention come norma L2 quadrata delle attivazioni.
Input features: [B, C, H, W]
Output: [B, H*W] normalizzato (attention map piatta)
"""
# Somma sui canali -> [B, H, W]
attention = features.pow(2).sum(dim=1)
# Appiattisci -> [B, H*W]
attention = attention.view(attention.size(0), -1)
# Normalizza L2 per ogni sample nel batch
return F.normalize(attention, p=2, dim=1)
def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
"""Calcola AT loss su più livelli."""
total_loss = torch.tensor(0.0)
for s_feat, t_feat in zip(student_features, teacher_features):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat).detach()
# Allinea dimensioni spaziali se necessario
if s_attn.shape != t_attn.shape:
s_h = int(s_feat.shape[2] * s_feat.shape[3])
t_h = int(t_feat.shape[2] * t_feat.shape[3])
# Usa interpolazione sull'attention map 2D
s_2d = s_feat.pow(2).mean(1, keepdim=True)
t_2d = t_feat.pow(2).mean(1, keepdim=True)
t_2d = F.interpolate(t_2d, size=s_feat.shape[2:], mode='bilinear')
s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
return self.beta * total_loss / max(len(student_features), 1)
# ============================================================
# DeiT-STYLE: Distillation Token per Vision Transformer
# ============================================================
class ViTWithDistillationToken(nn.Module):
"""
Aggiunge un distillation token a un ViT standard.
Come in DeiT: il token impara a replicare le predizioni
di un teacher CNN (es. RegNet, ResNet).
Durante inference: media CLS token + dist token.
Durante training: loss su entrambi i token.
"""
def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
super().__init__()
self.vit = vit_model
# Token di distillazione learnable
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.dist_token, std=0.02)
# Distillation head (separato dal CLS head)
self.dist_head = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor, return_dist: bool = False):
# Ottieni feature dal ViT
features = self.vit.forward_features(x)
# CLS prediction (predizione principale)
cls_pred = self.vit.head(features[:, 0])
# Dist token prediction (guida del teacher)
dist_pred = self.dist_head(features[:, 1]) # Assumendo dist_token al pos 1
if self.training:
return cls_pred, dist_pred
else:
# Inference: media delle due predizioni
return (cls_pred + dist_pred) / 2.0
def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
alpha: float = 0.5, temperature: float = 3.0):
"""
Loss DeiT: combina CE hard labels + KD dal teacher CNN.
"""
# Hard label loss sul CLS token
loss_cls = F.cross_entropy(cls_pred, labels)
# Soft label loss sul distillation token
loss_dist = F.kl_div(
F.log_softmax(dist_pred / temperature, dim=1),
F.softmax(teacher_pred / temperature, dim=1),
reduction='batchmean'
) * temperature ** 2
return alpha * loss_dist + (1 - alpha) * loss_cls
Distillazione per LLM: da Modelli Grandi a Piccoli
La distillazione per LLM segue gli stessi principi ma con alcune secificità importanti. Il vocabolario e enorme (32K-128K token), i modelli teacher e student devono essere compatibili a livello di tokenizer, e la loss opera a livello di ogni token nella sequenza. DistilBERT, DistilGPT2 e la famiglia Phi di Microsoft sono esempi di successo.
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
# ============================================================
# DISTILLAZIONE LLM: causal language modeling
# ============================================================
def distill_llm_batch(
teacher_model,
student_model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.7
) -> dict:
"""
Distillazione LLM per next-token prediction.
Funziona per GPT-style (causal) e BERT-style (masked).
teacher_model: modello grande (es. Llama-3-8B)
student_model: modello piccolo (es. Llama-3-1B)
alpha: peso KD loss (1-alpha = peso CE loss standard)
"""
device = next(student_model.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference (no gradients, può essere su device diverso)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids, attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # [B, seq_len, vocab_size]
# Student inference (con gradients)
student_outputs = student_model(
input_ids, attention_mask=attention_mask
)
student_logits = student_outputs.logits
# Shift per next-token prediction (esclude l'ultimo token come input)
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Ridimensiona per calcolo per-token loss
B, S, V = shift_student.shape
shift_student_flat = shift_student.view(B * S, V)
shift_teacher_flat = shift_teacher.view(B * S, V)
shift_labels_flat = shift_labels.view(B * S)
# 1. KD Loss: KL divergence per ogni token
student_log_probs = F.log_softmax(shift_student_flat / temperature, dim=-1)
teacher_probs = F.softmax(shift_teacher_flat / temperature, dim=-1)
loss_kd = F.kl_div(student_log_probs, teacher_probs,
reduction='batchmean') * temperature ** 2
# 2. CE Loss standard (con label -100 per token da ignorare)
loss_ce = F.cross_entropy(
shift_student_flat, shift_labels_flat,
ignore_index=-100 # Padding tokens
)
total = alpha * loss_kd + (1 - alpha) * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'perplexity': torch.exp(loss_ce).detach()
}
# ============================================================
# PIPELINE DISTILLAZIONE LLM COMPLETA
# ============================================================
def setup_llm_distillation(
teacher_name: str = "meta-llama/Llama-3.1-8B",
student_name: str = "meta-llama/Llama-3.2-1B",
device: str = "cuda"
):
"""
Setup per distillare un LLM grande in uno più piccolo.
IMPORTANTE: teacher e student devono condividere il tokenizer
per avere distribuzioni compatibili sullo stesso vocabolario.
"""
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Teacher: carica in FP16 per risparmiare memoria
teacher = AutoModelForCausalLM.from_pretrained(
teacher_name,
torch_dtype=torch.float16,
device_map="auto" # Distribuisce su più GPU se disponibili
)
teacher.eval()
# Student: carica in FP32 per training stabile
student = AutoModelForCausalLM.from_pretrained(
student_name,
torch_dtype=torch.float32
).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params/1e9:.1f}B parametri")
print(f"Student: {student_params/1e9:.1f}B parametri")
print(f"Compressione: {teacher_params/student_params:.1f}x")
return teacher, student, tokenizer
print("Setup distillazione LLM completato!")
Self-Distillation e Born Again Networks
La self-distillazione e una variante sorprendente: il modello fa da teacher a se stesso. In Born Again Networks (BANs, Furlanello et al. 2018), si addestrano generazioni successive di modelli con la stessa architettura: ogni generazione usa la precedente come teacher. Il risultato e un miglioramento sistematico (+1-2% Top-1) senza aumentare la dimensione del modello.
import torch
import torch.nn as nn
import copy
# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
n_generations: int = 3,
temperature: float = 4.0,
n_epochs: int = 30,
device: str = "cuda"):
"""
Allena N generazioni con la stessa architettura.
Gen 1: training standard con CE loss.
Gen 2+: distillazione dalla generazione precedente.
Risultati tipici CIFAR-100:
Gen 1: 67.1% (baseline)
Gen 2: 70.8% (+3.7%)
Gen 3: 72.1% (+5.0%)
Ensemble 1+2+3: 74.8% (+7.7%)
"""
criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
criterion_ce = nn.CrossEntropyLoss()
all_models = []
results = []
# === Generazione 1: training standard ===
print("Gen 1: training standard...")
gen1 = model_factory().to(device)
opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
gen1.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
opt1.zero_grad()
criterion_ce(gen1(imgs), labels).backward()
opt1.step()
sch1.step()
acc1 = _evaluate(gen1, val_loader, device)
results.append(acc1)
all_models.append(gen1)
print(f" Gen 1 val acc: {acc1:.4f}")
teacher = copy.deepcopy(gen1).eval()
# === Generazioni successive con distillazione ===
for gen_idx in range(2, n_generations + 1):
print(f"Gen {gen_idx}: KD da gen {gen_idx-1}...")
student = model_factory().to(device)
opt = torch.optim.SGD(student.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
student.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
with torch.no_grad():
t_logits = teacher(imgs)
s_logits = student(imgs)
losses = criterion_kd(s_logits, t_logits, labels)
opt.zero_grad()
losses['total'].backward()
opt.step()
sch.step()
acc = _evaluate(student, val_loader, device)
results.append(acc)
all_models.append(student)
print(f" Gen {gen_idx} val acc: {acc:.4f}")
teacher = copy.deepcopy(student).eval()
# Ensemble di tutti i modelli (upper bound)
ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
print(f"\nEnsemble {n_generations} gen: {ensemble_acc:.4f}")
return results, all_models
def _evaluate(model, loader, device):
model.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
correct += (model(imgs).argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
def _ensemble_evaluate(models, loader, device):
"""Ensemble averaging delle predizioni."""
for m in models:
m.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
logits_sum = torch.stack([m(imgs) for m in models]).mean(0)
correct += (logits_sum.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
Risultati Tipici di Distillazione (Benchmark 2024-2025)
| Task | Teacher | Student | Senza KD | Con KD | Teacher | Compressione |
|---|---|---|---|---|---|---|
| CIFAR-100 | ResNet-50 | MobileNetV3-S | 67.1% | 73.2% | 78.2% | 10x param |
| ImageNet | ViT-L/16 | DeiT-S | 79.8% | 83.1% | 87.1% | 5x param |
| GLUE (NLP) | BERT-Large | DistilBERT | 83.2% | 86.4% | 89.2% | 2x param, 2x speed |
| SQuAD (QA) | RoBERTa-L | DistilRoBERTa | 82.1% | 85.8% | 90.4% | 2x param |
| LLM (perplexity) | Llama 3.1 8B | Llama 3.2 1B | 8.24 PPL | 7.81 PPL | 6.12 PPL | 8x param |
KD tipicamente recupera il 70-85% del gap tra student e teacher con 2-10x meno parametri.
Pipeline Produzione: Distillazione + Quantizzazione
Il workflow più potente per il deployment edge combina distillazione e quantizzazione in sequenza: prima si crea lo student con KD (mantiene alta accuratezza), poi si quantizza lo student (riduce dimensioni e aumenta velocità). La combinazione può ridurre un modello da 100x a 40x rispetto al teacher originale con solo 5-10% di perdita di accuratezza.
import torch
import torch.nn as nn
from torchvision import models
# ============================================================
# PIPELINE COMPLETA: Distillazione -> Quantizzazione -> ONNX
# ============================================================
def full_compression_pipeline(teacher_path: str, output_dir: str = "./compressed"):
"""
Pipeline completa per comprimere un modello per edge deployment.
Step 1: Carica teacher pre-trainato
Step 2: Distilla in student più piccolo
Step 3: Quantizza lo student (PTQ INT8)
Step 4: Esporta in ONNX per deployment cross-platform
"""
import os
os.makedirs(output_dir, exist_ok=True)
# STEP 1: Teacher
print("Step 1: Carico teacher...")
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10) # 10 classi
# In produzione: teacher.load_state_dict(torch.load(teacher_path))
teacher.eval()
teacher_size_mb = sum(p.numel() * p.element_size()
for p in teacher.parameters()) / (1024**2)
print(f" Teacher: {teacher_size_mb:.1f} MB, "
f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M param")
# STEP 2: Student (dopo distillazione)
print("Step 2: Student con distillazione (simulato con MobileNetV3)...")
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, 10
)
# In produzione: train_with_distillation(teacher, student, ...)
# student.load_state_dict(torch.load("best_student.pth"))
student_size_mb = sum(p.numel() * p.element_size()
for p in student.parameters()) / (1024**2)
print(f" Student: {student_size_mb:.1f} MB, "
f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M param")
print(f" Riduzione rispetto teacher: {teacher_size_mb/student_size_mb:.1f}x")
# STEP 3: Quantizzazione INT8 (PTQ)
print("Step 3: Quantizzazione INT8...")
student.eval()
# Quantizzazione dinamica (più semplice, leggermente meno efficiente)
student_quant = torch.quantization.quantize_dynamic(
student,
{nn.Linear}, # Quantizza solo Linear (Conv2d richiede calibrazione)
dtype=torch.qint8
)
# Per quantizzazione statica (più efficiente):
# student.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(student, inplace=True)
# calibrate(student, calib_loader) # Forward pass su dati di calibrazione
# torch.quantization.convert(student, inplace=True)
quant_size_mb = sum(p.numel() * p.element_size()
for p in student_quant.parameters()) / (1024**2)
print(f" Student INT8: ~{student_size_mb/4:.1f} MB (stima)")
print(f" Riduzione totale: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")
# STEP 4: Export ONNX
print("Step 4: Export ONNX...")
dummy = torch.randn(1, 3, 224, 224)
onnx_path = f"{output_dir}/student_compressed.onnx"
torch.onnx.export(
student, # Usa FP32 per ONNX (compatibilità più ampia)
dummy,
onnx_path,
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}},
export_params=True
)
print(f"\n=== RIEPILOGO PIPELINE ===")
print(f"Teacher (ResNet-50): {teacher_size_mb:.1f} MB")
print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x riduzione)")
print(f"Student INT8 (stimato): {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x riduzione)")
print(f"ONNX salvato: {onnx_path}")
return student_quant, onnx_path
full_compression_pipeline("teacher_weights.pth")
Anti-Pattern nella Distillazione: Errori Comuni
- Temperature troppo alta o troppo bassa: T=1 equivale a hard labels. T troppo alta (>20) rende le soft labels quasi uniformi, perdendo il segnale. Esegui sempre un ablation study con T ∈ {2, 4, 6, 8} sul tuo dataset specifico.
- Teacher e student troppo diversi in capacità: se il gap e enorme (GPT-4 a 7B), la distillazione diretta e inefficace. Usa distillazione a cascata: GPT-4 -> 13B -> 7B -> 3B. Ogni step non dovrebbe superare 4-5x di riduzione.
- Ignorare la qualità del dataset di distillazione: la qualità del dataset su cui esegui la distillazione impatta enormemente. Usa dati diversificati e rappresentativi del task target. Dati out-of-distribution danneggiano il transfer.
- Alpha mal calibrato: con alpha=1 (solo KD) lo student ignora le etichette ground truth e può generare predizioni instabili se il teacher sbaglia. Con alpha=0 si riduce a training standard. Valori 0.5-0.8 sono tipicamente ottimali.
- Non congelare il teacher: il teacher deve essere in modalità eval() durante il training dello student. Se il teacher continua a cambiare (es. ha BatchNorm in train mode), i targets di distillazione sono inconsistenti e il training diverge.
Distillazione per LLM: Speculative Decoding e Response Distillation
Nel contesto dei Large Language Models, la distillazione assume forme nuove e potenti. Due tecniche particolarmente rilevanti nel 2025-2026 sono la Response Distillation (usata per addestrare Llama-3.2 e i modelli Phi di Microsoft) e lo Speculative Decoding, che sfrutta un modello piccolo per accelerare l'inferenza del modello grande senza perdita di qualità.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principio: un modello piccolo (draft) genera K token in anticipo.
# Il modello grande (target) verifica tutti i K token in un solo forward pass.
# Se il draft ha ragione, si risparmiano K-1 forward pass del modello grande.
# Speedup tipico: 2-4x senza perdita di qualità.
class SpeculativeDecoder:
"""
Implementazione base di speculative decoding.
Draft model: modello piccolo (es. Llama-3.2-1B)
Target model: modello grande (es. Llama-3.1-8B)
Riferimento: "Fast Inference from Transformers via Speculative Decoding"
(Leviathan et al., 2022) - il paper originale di Google.
"""
def __init__(
self,
draft_model_name: str,
target_model_name: str,
device: str = "cuda",
lookahead_k: int = 5 # Token generati dal draft per ogni step
):
self.device = device
self.lookahead_k = lookahead_k
print(f"Carico draft model: {draft_model_name}...")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16
).to(device).eval()
print(f"Carico target model: {target_model_name}...")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16
).to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def draft_generate(self, input_ids: torch.Tensor) -> tuple:
"""
Il draft model genera K token e restituisce
le distribuzioni di probabilità per acceptance/rejection.
"""
draft_ids = input_ids.clone()
draft_probs = []
with torch.no_grad():
for _ in range(self.lookahead_k):
out = self.draft_model(draft_ids)
next_logits = out.logits[:, -1, :]
next_probs = F.softmax(next_logits, dim=-1)
# Campiona dal draft
next_token = torch.multinomial(next_probs, num_samples=1)
draft_probs.append(next_probs)
draft_ids = torch.cat([draft_ids, next_token], dim=1)
# draft_ids ora include K token aggiuntivi
draft_tokens = draft_ids[:, input_ids.shape[1]:]
return draft_tokens, draft_probs
def speculative_generate(
self,
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.7
) -> str:
"""Genera testo con speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated = input_ids.clone()
total_accepted = 0
total_draft = 0
with torch.no_grad():
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. Draft genera K token
draft_tokens, draft_probs = self.draft_generate(generated)
total_draft += self.lookahead_k
# 2. Target verifica tutti i K+1 token in un forward pass
full_seq = torch.cat([generated, draft_tokens], dim=1)
target_out = self.target_model(full_seq)
target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
# 3. Acceptance-rejection sampling
n_accepted = 0
for i in range(draft_tokens.shape[1]):
draft_tok = draft_tokens[0, i].item()
target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
draft_p = draft_probs[i][0, draft_tok].item()
target_p = target_probs[draft_tok].item()
# Accetta se target d'accordo con draft
r = torch.rand(1).item()
if r < min(1.0, target_p / (draft_p + 1e-10)):
n_accepted += 1
else:
# Rifiuta: campiona dal target corretto
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]),
num_samples=1
)
generated = torch.cat([
generated,
draft_tokens[:, :i],
corrected.unsqueeze(0)
], dim=1)
break
else:
# Tutti accettati: aggiungi bonus token dal target
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (atteso 60-80% con draft simile)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION per LLM (semplificata)
# ============================================================
# Tecnica usata da Llama-3.2, Phi-3, Mistral-7B-Instruct:
# 1. Teacher LLM grande (es. GPT-4, Llama-3.1-70B) genera risposte
# 2. Student LLM piccolo impara a imitare quelle risposte
# Diverso dalla distillazione classica: distilla risposte (output testo),
# non distribuzioni di probabilità interne.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Genera dataset di distillazione con risposte del teacher.
In produzione: usa GPT-4 API, Llama-3.1-70B, o Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generato: {len(dataset)} esempi -> {output_file}")
return dataset
# Note: in pratica, usa l'API di un servizio commerciale (OpenAI, Anthropic)
# per generare le risposte del "teacher", poi addestra lo student su di esse.
# Questa e la tecnica dietro la maggior parte dei modelli instruction-following
# come Alpaca, Vicuna, Orca, e i modelli Phi di Microsoft.
print("LLM distillation patterns pronti")
Distillazione: Confronto Tecnico tra Varianti (2024-2025)
| Variante | Cosa Distilla | Pro | Contro | Uso Tipico |
|---|---|---|---|---|
| Soft Labels (Hinton 2015) | Distribuzioni di probabilità | Informazione ricca, standard | Richiede accesso ai logit del teacher | Vision, classificazione |
| Feature Distillation | Rappresentazioni intermedie | Transfer di feature profonde | Teacher e student devono avere architettura compatibile | Detection, segmentation |
| Response Distillation | Output testuale del teacher | Non richiede accesso interno | Perde informazione sulla incertezza | LLM instruction-following |
| Born-Again Networks | Self-distillation iterativa | Non richiede teacher separato | Guadagno limitato, costo computazionale alto | Ensemble, miglioramento |
| Speculative Decoding | Non e distillazione ma usa draft | 2-4x speedup senza perdita | Richiede due modelli in memoria | LLM inference acceleration |
Conclusioni
La Knowledge Distillation e una delle tecniche di compressione più potenti e versatili disponibili nel 2026. Combina naturalmente con quantizzazione e pruning: distilla prima per creare lo student ottimale, poi quantizza lo student per il deployment edge. Il risultato e spesso un modello 10-100x più piccolo del teacher con solo il 5-15% di perdita di accuratezza.
Per i LLM, la distillazione ha abilitato l'intera famiglia dei modelli "Distil*": DistilBERT, DistilGPT2, e i modelli Phi di Microsoft (2.7B con qualità di modelli 7B). Il trend del 2026 — Small Language Models (SLMs) che surclassano i LLM per frequenza di utilizzo secondo Gartner — e reso possibile proprio dalla distillazione, che trasferisce la conoscenza dei giganti a modelli che girano su Raspberry Pi e smartphone.
L'articolo successivo mostra come deployare questi modelli compressi sugli edge devices: Raspberry Pi, NVIDIA Jetson e hardware embedded, con tutte le ottimizzazioni necessarie per l'ambiente produttivo reale.
Prossimi Passi
- Articolo successivo: Deep Learning su Edge Devices: Da Cloud a Edge
- Correlato: Quantizzazione dei Modelli: GPTQ, AWQ, INT8
- Correlato: Pruning Reti Neurali: Ridurre Parametri
- Correlato: Vision Transformer: distillazione con DeiT
- Serie MLOps: Serving di Modelli Compressi in Produzione







