Distilarea cunoștințelor: comprimarea modelelor complexe
GPT-4 este prea mare pentru a rula pe laptop. Un ResNet-152 este prea lent pentru al tău dispozitiv mobil. Cu toate acestea, cunoștințele dobândite din aceste modele gigantice prin săptămâni de instruire pe clustere GPU poate fi transferat la modele mult mai mic și mai rapid, cu o pierdere surprinzător de modestă de precizie. Aceasta este promisiunea de Distilarea cunoștințelor.
Propusă de Hinton, Vinyals și Dean în 2015, distilarea a devenit una dintre tehnici cel mai puternic și versatil din setul de instrumente modern de deep learning. Principiul este elegant: un model mare (profesor) ghidează antrenamentul unui model mai mic (student) nu doar cu etichetele dure (0/1), ci cu etichetele probabilități blânde ale profesorului — distribuții de probabilitate bogate în informații despre asemănările dintre clase. Această informație implicită — „cunoașterea întunecată” — este ceea ce face distilarea atât de eficient.
În acest ghid explorăm în profunzime distilarea: de la teoria originală la variații modern (distilare caracteristică, transfer de atenție, auto-distilare, distilare LLM), cu Implementări complete în PyTorch, cele mai bune practici pentru producție și un studiu de caz real.
Ce vei învăța
- Teoria distilației: etichete moi, temperatură și cunoștințe întunecate
- Implementarea distilării standard cu PyTorch
- Distilare caracteristică: Transferarea reprezentărilor intermediare
- Transfer de atenție: Distilarea hărților de atenție ale Transformers
- Auto-distilare și rețele Born Again
- Offline vs online vs autodistilare
- Distilare pentru LLM: de la modele mari la modele de margine
- Combinați distilarea cu cuantizarea și tăierea
- Cele mai bune practici, erori comune și metrici de evaluare
- Studiu de caz: DistilBERT pas cu pas
Principiul distilarii: Cunoasterea intunecata
Un clasificator standard instruit pe etichete dure (one-hot) folosește foarte puține informații: clasa corectă are probabilitatea 1, toate celelalte 0. Dar un model bine pregătit știe multe mai mult: știe că o pisică seamănă mai mult cu un câine decât cu o mașină. Această informație este cuprinse în distribuția de probabilitate a profesorului, chiar și atunci când răspunsul este corect are probabilitate aproape 1.
Trucul la distilare este să folosești unul temperatura T a "inmoaie" probabilitățile profesorului, amplificând diferențele dintre orele improbabile, dar informative. Cu T=1 avem distribuția originală; cu T mare (de exemplu T=4) probabilitățile devin mai mari uniforme, dezvăluind relații de similitudine implicite. Acest mecanism se numește cunoștințe întunecate — cunoștințele ascunse în logit-urile modelului care a eticheta binară simplă nu poate captura.
# Visualizzazione effetto della temperatura sulla dark knowledge
import torch
import torch.nn.functional as F
import numpy as np
# Supponiamo che il teacher produca questi logits per un campione
# con classe vera = 0 (gatto)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Classi ipotetiche: [gatto, cane, felino, leone, volpe, auto, aereo, nave, treno, barca]
classi = ["gatto", "cane", "felino", "leone", "volpe", "auto", "aereo", "nave", "treno", "barca"]
print("Effetto della temperatura sulle soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
probs = F.softmax(teacher_logits / T, dim=0)
entropia = -(probs * probs.log()).sum().item()
print(f"T={T:2d}: p(gatto)={probs[0]:.4f}, "
f"p(cane)={probs[1]:.4f}, p(felino)={probs[2]:.4f}, "
f"entropia={entropia:.3f}")
# Output:
# T= 1: p(gatto)=0.9833, p(cane)=0.0106, p(felino)=0.0079, entropia=0.123
# T= 2: p(gatto)=0.9175, p(cane)=0.0440, p(felino)=0.0297, entropia=0.424
# T= 4: p(gatto)=0.7562, p(cane)=0.1168, p(felino)=0.0913, entropia=0.895
# T= 8: p(gatto)=0.5756, p(cane)=0.1572, p(felino)=0.1343, entropia=1.387
# T=20: p(gatto)=0.3520, p(cane)=0.1668, p(felino)=0.1600, entropia=1.944
#
# Con T alta, il teacher rivela che "gatto" e molto simile a "cane" e "felino"
# ma niente a che fare con "auto" o "aereo".
# Questa e la DARK KNOWLEDGE che lo student impara!
print("\nAnalisi della dark knowledge:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, classe in enumerate(classi):
print(f" {classe:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
Matematica distilarii
Pierderea prin distilare combină doi termeni cu un hiperparametru de echilibrare alfa:
- L_distill: Divergența KL între probabilitățile soft ale elevului și profesorului (la temperatura T), înmulțită cu T² pentru a compensa reducerea dimensiunii gradientului
- L_student: Entropie încrucișată standard între predicțiile elevilor și etichetele dure
Formula finală este: L = alfa * T² * KL(student_soft || profesor_soft) + (1-alpha) * CE(student, etichete)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# ============================================================
# DISTILLATION LOSS - Implementazione completa
# ============================================================
class DistillationLoss(nn.Module):
"""
Loss per Knowledge Distillation (Hinton et al., 2015).
L = alpha * T^2 * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
Il fattore T^2 e fondamentale: quando si usa T > 1,
i gradienti si riducono di 1/T^2. Moltiplicando per T^2
si compensa questo effetto, mantenendo scale dei gradienti
coerenti tra la KD loss e la CE loss.
"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
"""
temperature: scala le probabilità soft (tipico: 2-8)
alpha: peso della distillation loss (tipico: 0.5-0.9)
"""
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor) -> dict:
# Soft probabilities a temperatura T
# NB: log_softmax per student (richiesto da KLDivLoss)
student_soft = F.log_softmax(student_logits / self.T, dim=1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
# KL divergence * T^2 per compensare la riduzione del gradiente
loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
# Cross-entropy standard con hard labels
loss_student = self.ce_loss(student_logits, labels)
# Combinazione pesata
total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
return {
'total': total_loss,
'distill': loss_distill.detach(),
'student': loss_student.detach()
}
# ============================================================
# MODELLI TEACHER e STUDENT
# ============================================================
def create_teacher_student(n_classes: int = 100):
"""
Teacher: ResNet-50 (~25M parametri) - pre-trained su ImageNet
Student: MobileNetV3-Small (~2.5M parametri) - 10x più piccolo
"""
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, n_classes
)
total_teacher = sum(p.numel() for p in teacher.parameters())
total_student = sum(p.numel() for p in student.parameters())
flops_teacher = 4.1e9 # Approx FLOPs ResNet-50
flops_student = 0.056e9 # Approx FLOPs MobileNetV3-S
print(f"Teacher (ResNet-50): {total_teacher:,} param, {flops_teacher/1e9:.1f}G FLOPs")
print(f"Student (MobileNetV3): {total_student:,} param, {flops_student*1000:.0f}M FLOPs")
print(f"Fattore compressione: {total_teacher/total_student:.1f}x param, "
f"{flops_teacher/flops_student:.0f}x FLOPs")
return teacher, student
# ============================================================
# TRAINING LOOP CON DISTILLAZIONE
# ============================================================
def train_with_distillation(
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.7,
lr: float = 1e-3,
device: str = "cuda"
):
teacher = teacher.to(device).eval() # Teacher: SOLO inference, no backprop!
student = student.to(device)
criterion = DistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
best_acc = 0.0
history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}
for epoch in range(n_epochs):
student.train()
total_loss = distill_loss_sum = student_loss_sum = 0.0
n_batches = 0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass teacher (CRITICO: no gradients, risparmia memoria!)
with torch.no_grad():
teacher_logits = teacher(imgs)
# Forward pass student (con gradients)
student_logits = student(imgs)
# Loss combinata
losses = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
losses['total'].backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
optimizer.step()
total_loss += losses['total'].item()
distill_loss_sum += losses['distill'].item()
student_loss_sum += losses['student'].item()
n_batches += 1
scheduler.step()
# Validation
student.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
preds = student(imgs).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_total = total_loss / n_batches
history['train_loss'].append(avg_total)
history['val_acc'].append(val_acc)
history['distill_loss'].append(distill_loss_sum / n_batches)
history['student_loss'].append(student_loss_sum / n_batches)
if val_acc > best_acc:
best_acc = val_acc
torch.save(student.state_dict(), "best_student.pth")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")
print(f"\nMiglior accuracy student: {best_acc:.4f}")
return history, best_acc
# Risultati tipici CIFAR-100:
# ResNet-50 teacher: 78.2% Top-1
# MobileNetV3-S senza KD: 67.1% Top-1
# MobileNetV3-S con KD: 71.4% Top-1 (+4.3%)
# MobileNetV3-S con KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x param, 73x FLOPs
Distilarea caracteristicilor: transferul reprezentărilor interne
Distilarea pe etichete moi transferă doar rezultatul final al profesorului. Acolo Caracteristică Distilarea merge mai departe: îl obligă pe elev să reproducă și el reprezentări intermediare ale profesorului — hărțile caracteristicilor la diferite niveluri ale rețelei. Acest lucru este deosebit de eficient atunci când profesorii și studenții au arhitecturi foarte diferite (de exemplu, profesor CNN, student ViT) și când sarcina necesită caracteristici spațiale bogate.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
"""Cattura le feature di layer specifici tramite forward hooks."""
def __init__(self, model: nn.Module, layer_names: list):
self.features = {}
self.hooks = []
for name, module in model.named_modules():
if name in layer_names:
hook = module.register_forward_hook(
lambda m, inp, out, n=name: self.features.update({n: out})
)
self.hooks.append(hook)
def get_features(self) -> list:
return list(self.features.values())
def clear(self):
self.features.clear()
def remove(self):
"""Rimuovi hooks per evitare memory leaks."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
"""
Loss che combina:
1. KD loss standard (soft labels output)
2. Feature Matching Loss (MSE tra feature intermedie normalizzate)
3. Relation-Based Loss (distanze relative tra sample nel batch)
"""
def __init__(self, student_channels: list, teacher_channels: list,
temperature: float = 4.0, alpha: float = 0.4,
beta: float = 0.4, gamma: float = 0.2):
"""
alpha: peso KD loss
beta: peso feature matching loss
gamma: peso CE loss standard
(alpha + beta + gamma deve essere 1.0)
"""
super().__init__()
assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Pesi devono sommare a 1"
self.T = temperature
self.alpha = alpha
self.beta = beta
self.gamma = gamma
# Adattatori per allineare dimensioni teacher->student
# Esempio: teacher ha 2048 canali, student 96 -> adapter 1x1 conv
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(t_ch, s_ch, 1, bias=False),
nn.BatchNorm2d(s_ch),
nn.ReLU(inplace=True)
)
for t_ch, s_ch in zip(teacher_channels, student_channels)
])
def forward(self, student_logits, teacher_logits, labels,
student_features: list, teacher_features: list):
# 1. KD Loss (soft labels)
kl = nn.KLDivLoss(reduction='batchmean')
loss_kd = kl(
F.log_softmax(student_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1)
) * self.T ** 2
# 2. CE Loss standard
loss_ce = F.cross_entropy(student_logits, labels)
# 3. Feature Matching Loss
loss_feat = torch.tensor(0.0, device=student_logits.device)
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Adatta canali del teacher allo student
t_adapted = self.adapters[i](t_feat.detach())
# Allinea risoluzione spaziale se necessario
if s_feat.shape[2:] != t_adapted.shape[2:]:
t_adapted = F.interpolate(
t_adapted, size=s_feat.shape[2:],
mode='bilinear', align_corners=False
)
# Normalizza le feature (cosine similarity invece di MSE)
s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
# MSE tra feature normalizzate
loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
loss_feat = loss_feat / max(len(student_features), 1)
total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'feat': loss_feat.detach()
}
# Configurazione per ResNet-50 teacher -> MobileNetV3-S student
# Layer teacher: [layer2, layer3, layer4] -> Canali: [512, 1024, 2048]
# Layer student: [features.4, features.9, features.12] -> Canali: [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)
teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]
teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)
feat_criterion = FeatureDistillationLoss(
student_channels=student_channels,
teacher_channels=teacher_channels,
temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup completato!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Transfer de atenție pentru Transformer și Vision Transformer
Vision Transformers produc hărți explicite de atenție care pot fi direct distilat. DeiT (Data-efficient Image Transformer) utilizează această abordare cu un jeton special de distilare. THE'Transfer de atenție (Zagoruyko & Komodakis, 2017) extinde conceptul și la CNN-uri, construind hărți de atenție din activările straturilor convoluţionale.
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ============================================================
# ATTENTION TRANSFER (AT) per CNN
# ============================================================
class AttentionTransferLoss(nn.Module):
"""
Attention Transfer (Zagoruyko & Komodakis, 2017).
Forza lo student a replicare le attention maps del teacher.
Efficace per transfer tra architetture diverse (CNN <-> ViT).
"""
def __init__(self, beta: float = 1000.0):
super().__init__()
self.beta = beta
def attention_map(self, features: torch.Tensor) -> torch.Tensor:
"""
Calcola mappa di attention come norma L2 quadrata delle attivazioni.
Input features: [B, C, H, W]
Output: [B, H*W] normalizzato (attention map piatta)
"""
# Somma sui canali -> [B, H, W]
attention = features.pow(2).sum(dim=1)
# Appiattisci -> [B, H*W]
attention = attention.view(attention.size(0), -1)
# Normalizza L2 per ogni sample nel batch
return F.normalize(attention, p=2, dim=1)
def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
"""Calcola AT loss su più livelli."""
total_loss = torch.tensor(0.0)
for s_feat, t_feat in zip(student_features, teacher_features):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat).detach()
# Allinea dimensioni spaziali se necessario
if s_attn.shape != t_attn.shape:
s_h = int(s_feat.shape[2] * s_feat.shape[3])
t_h = int(t_feat.shape[2] * t_feat.shape[3])
# Usa interpolazione sull'attention map 2D
s_2d = s_feat.pow(2).mean(1, keepdim=True)
t_2d = t_feat.pow(2).mean(1, keepdim=True)
t_2d = F.interpolate(t_2d, size=s_feat.shape[2:], mode='bilinear')
s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
return self.beta * total_loss / max(len(student_features), 1)
# ============================================================
# DeiT-STYLE: Distillation Token per Vision Transformer
# ============================================================
class ViTWithDistillationToken(nn.Module):
"""
Aggiunge un distillation token a un ViT standard.
Come in DeiT: il token impara a replicare le predizioni
di un teacher CNN (es. RegNet, ResNet).
Durante inference: media CLS token + dist token.
Durante training: loss su entrambi i token.
"""
def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
super().__init__()
self.vit = vit_model
# Token di distillazione learnable
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.dist_token, std=0.02)
# Distillation head (separato dal CLS head)
self.dist_head = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor, return_dist: bool = False):
# Ottieni feature dal ViT
features = self.vit.forward_features(x)
# CLS prediction (predizione principale)
cls_pred = self.vit.head(features[:, 0])
# Dist token prediction (guida del teacher)
dist_pred = self.dist_head(features[:, 1]) # Assumendo dist_token al pos 1
if self.training:
return cls_pred, dist_pred
else:
# Inference: media delle due predizioni
return (cls_pred + dist_pred) / 2.0
def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
alpha: float = 0.5, temperature: float = 3.0):
"""
Loss DeiT: combina CE hard labels + KD dal teacher CNN.
"""
# Hard label loss sul CLS token
loss_cls = F.cross_entropy(cls_pred, labels)
# Soft label loss sul distillation token
loss_dist = F.kl_div(
F.log_softmax(dist_pred / temperature, dim=1),
F.softmax(teacher_pred / temperature, dim=1),
reduction='batchmean'
) * temperature ** 2
return alpha * loss_dist + (1 - alpha) * loss_cls
Distilare pentru LLM: de la modele mari la modele mici
Distilarea pentru LLM urmează aceleași principii, dar cu unele specificități importante. Vocabularul este imens (32K-128K jetoane), modelele profesorului și elevilor trebuie să fie compatibil la nivel de tokenizer, iar pierderea operează la nivelul fiecărui jeton din secvență. DistilBERT, DistilGPT2 și familia Phi a Microsoft sunt exemple de succes.
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
# ============================================================
# DISTILLAZIONE LLM: causal language modeling
# ============================================================
def distill_llm_batch(
teacher_model,
student_model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.7
) -> dict:
"""
Distillazione LLM per next-token prediction.
Funziona per GPT-style (causal) e BERT-style (masked).
teacher_model: modello grande (es. Llama-3-8B)
student_model: modello piccolo (es. Llama-3-1B)
alpha: peso KD loss (1-alpha = peso CE loss standard)
"""
device = next(student_model.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference (no gradients, può essere su device diverso)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids, attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # [B, seq_len, vocab_size]
# Student inference (con gradients)
student_outputs = student_model(
input_ids, attention_mask=attention_mask
)
student_logits = student_outputs.logits
# Shift per next-token prediction (esclude l'ultimo token come input)
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Ridimensiona per calcolo per-token loss
B, S, V = shift_student.shape
shift_student_flat = shift_student.view(B * S, V)
shift_teacher_flat = shift_teacher.view(B * S, V)
shift_labels_flat = shift_labels.view(B * S)
# 1. KD Loss: KL divergence per ogni token
student_log_probs = F.log_softmax(shift_student_flat / temperature, dim=-1)
teacher_probs = F.softmax(shift_teacher_flat / temperature, dim=-1)
loss_kd = F.kl_div(student_log_probs, teacher_probs,
reduction='batchmean') * temperature ** 2
# 2. CE Loss standard (con label -100 per token da ignorare)
loss_ce = F.cross_entropy(
shift_student_flat, shift_labels_flat,
ignore_index=-100 # Padding tokens
)
total = alpha * loss_kd + (1 - alpha) * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'perplexity': torch.exp(loss_ce).detach()
}
# ============================================================
# PIPELINE DISTILLAZIONE LLM COMPLETA
# ============================================================
def setup_llm_distillation(
teacher_name: str = "meta-llama/Llama-3.1-8B",
student_name: str = "meta-llama/Llama-3.2-1B",
device: str = "cuda"
):
"""
Setup per distillare un LLM grande in uno più piccolo.
IMPORTANTE: teacher e student devono condividere il tokenizer
per avere distribuzioni compatibili sullo stesso vocabolario.
"""
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Teacher: carica in FP16 per risparmiare memoria
teacher = AutoModelForCausalLM.from_pretrained(
teacher_name,
torch_dtype=torch.float16,
device_map="auto" # Distribuisce su più GPU se disponibili
)
teacher.eval()
# Student: carica in FP32 per training stabile
student = AutoModelForCausalLM.from_pretrained(
student_name,
torch_dtype=torch.float32
).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params/1e9:.1f}B parametri")
print(f"Student: {student_params/1e9:.1f}B parametri")
print(f"Compressione: {teacher_params/student_params:.1f}x")
return teacher, student, tokenizer
print("Setup distillazione LLM completato!")
Auto-distilare și rețele Born Again
La auto-distilarea și o variantă surprinzătoare: modelul acționează ca profesor la sine. În Rețelele Born Again (BANs, Furlanello et al. 2018), they train generații succesive de modele cu aceeași arhitectură: fiecare generație folosește anterior ca profesor. Rezultatul este o îmbunătățire sistematică (+1-2% Top-1) fără increase the size of the model.
import torch
import torch.nn as nn
import copy
# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
n_generations: int = 3,
temperature: float = 4.0,
n_epochs: int = 30,
device: str = "cuda"):
"""
Allena N generazioni con la stessa architettura.
Gen 1: training standard con CE loss.
Gen 2+: distillazione dalla generazione precedente.
Risultati tipici CIFAR-100:
Gen 1: 67.1% (baseline)
Gen 2: 70.8% (+3.7%)
Gen 3: 72.1% (+5.0%)
Ensemble 1+2+3: 74.8% (+7.7%)
"""
criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
criterion_ce = nn.CrossEntropyLoss()
all_models = []
results = []
# === Generazione 1: training standard ===
print("Gen 1: training standard...")
gen1 = model_factory().to(device)
opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
gen1.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
opt1.zero_grad()
criterion_ce(gen1(imgs), labels).backward()
opt1.step()
sch1.step()
acc1 = _evaluate(gen1, val_loader, device)
results.append(acc1)
all_models.append(gen1)
print(f" Gen 1 val acc: {acc1:.4f}")
teacher = copy.deepcopy(gen1).eval()
# === Generazioni successive con distillazione ===
for gen_idx in range(2, n_generations + 1):
print(f"Gen {gen_idx}: KD da gen {gen_idx-1}...")
student = model_factory().to(device)
opt = torch.optim.SGD(student.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
student.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
with torch.no_grad():
t_logits = teacher(imgs)
s_logits = student(imgs)
losses = criterion_kd(s_logits, t_logits, labels)
opt.zero_grad()
losses['total'].backward()
opt.step()
sch.step()
acc = _evaluate(student, val_loader, device)
results.append(acc)
all_models.append(student)
print(f" Gen {gen_idx} val acc: {acc:.4f}")
teacher = copy.deepcopy(student).eval()
# Ensemble di tutti i modelli (upper bound)
ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
print(f"\nEnsemble {n_generations} gen: {ensemble_acc:.4f}")
return results, all_models
def _evaluate(model, loader, device):
model.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
correct += (model(imgs).argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
def _ensemble_evaluate(models, loader, device):
"""Ensemble averaging delle predizioni."""
for m in models:
m.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
logits_sum = torch.stack([m(imgs) for m in models]).mean(0)
correct += (logits_sum.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
Rezultate tipice de distilare (Etalon de referință 2024-2025)
| Sarcini | Profesor | Student | Fara KD | Cu KD | Profesor | Comprimare |
|---|---|---|---|---|---|---|
| CIFAR-100 | ResNet-50 | MobileNetV3-S | 67,1% | 73,2% | 78,2% | 10x param |
| ImageNet | ViT-L/16 | DeiT-S | 79,8% | 83,1% | 87,1% | 5x param |
| Adeziv (NLP) | BERT-Mare | DistilBERT | 83,2% | 86,4% | 89,2% | 2x param, 2x viteza |
| SQUAD (QA) | ROBERTa-L | DistilRoBERTa | 82,1% | 85,8% | 90,4% | 2x param |
| LLM (perplexitate) | Lama 3.1 8B | Lama 3.2 1B | 8,24 PPL | 7,81 PPL | 6.12 PPL | 8x param |
KD recuperează de obicei 70-85% din decalajul dintre elev și profesor cu de 2-10 ori mai puțini parametri.
Conducta de producție: distilare + cuantificare
Cel mai puternic flux de lucru pentru implementarea edge combină distilarea și cuantizarea în secvență: creați mai întâi elevul cu KD (menține precizia ridicată), apoi cuantificați elevul (reduce dimensiunea și crește viteza). Combinația poate reduce un model de la 100x la 40x în comparație cu profesorul original cu doar 5-10% pierderi de precizie.
import torch
import torch.nn as nn
from torchvision import models
# ============================================================
# PIPELINE COMPLETA: Distillazione -> Quantizzazione -> ONNX
# ============================================================
def full_compression_pipeline(teacher_path: str, output_dir: str = "./compressed"):
"""
Pipeline completa per comprimere un modello per edge deployment.
Step 1: Carica teacher pre-trainato
Step 2: Distilla in student più piccolo
Step 3: Quantizza lo student (PTQ INT8)
Step 4: Esporta in ONNX per deployment cross-platform
"""
import os
os.makedirs(output_dir, exist_ok=True)
# STEP 1: Teacher
print("Step 1: Carico teacher...")
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10) # 10 classi
# In produzione: teacher.load_state_dict(torch.load(teacher_path))
teacher.eval()
teacher_size_mb = sum(p.numel() * p.element_size()
for p in teacher.parameters()) / (1024**2)
print(f" Teacher: {teacher_size_mb:.1f} MB, "
f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M param")
# STEP 2: Student (dopo distillazione)
print("Step 2: Student con distillazione (simulato con MobileNetV3)...")
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, 10
)
# In produzione: train_with_distillation(teacher, student, ...)
# student.load_state_dict(torch.load("best_student.pth"))
student_size_mb = sum(p.numel() * p.element_size()
for p in student.parameters()) / (1024**2)
print(f" Student: {student_size_mb:.1f} MB, "
f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M param")
print(f" Riduzione rispetto teacher: {teacher_size_mb/student_size_mb:.1f}x")
# STEP 3: Quantizzazione INT8 (PTQ)
print("Step 3: Quantizzazione INT8...")
student.eval()
# Quantizzazione dinamica (più semplice, leggermente meno efficiente)
student_quant = torch.quantization.quantize_dynamic(
student,
{nn.Linear}, # Quantizza solo Linear (Conv2d richiede calibrazione)
dtype=torch.qint8
)
# Per quantizzazione statica (più efficiente):
# student.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(student, inplace=True)
# calibrate(student, calib_loader) # Forward pass su dati di calibrazione
# torch.quantization.convert(student, inplace=True)
quant_size_mb = sum(p.numel() * p.element_size()
for p in student_quant.parameters()) / (1024**2)
print(f" Student INT8: ~{student_size_mb/4:.1f} MB (stima)")
print(f" Riduzione totale: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")
# STEP 4: Export ONNX
print("Step 4: Export ONNX...")
dummy = torch.randn(1, 3, 224, 224)
onnx_path = f"{output_dir}/student_compressed.onnx"
torch.onnx.export(
student, # Usa FP32 per ONNX (compatibilità più ampia)
dummy,
onnx_path,
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}},
export_params=True
)
print(f"\n=== RIEPILOGO PIPELINE ===")
print(f"Teacher (ResNet-50): {teacher_size_mb:.1f} MB")
print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x riduzione)")
print(f"Student INT8 (stimato): {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x riduzione)")
print(f"ONNX salvato: {onnx_path}")
return student_quant, onnx_path
full_compression_pipeline("teacher_weights.pth")
Anti-modele în distilare: greșeli frecvente
- Temperaturi prea ridicate sau prea scăzute: T=1 este egal cu etichete dure. T prea mare (>20) face etichetele moi aproape uniforme, pierzând semnalul. Efectuați întotdeauna un studiu de ablație cu T ∈ {2, 4, 6, 8} pe setul dvs. de date specific.
- Profesor și elev prea diferiți ca abilități: dacă decalajul este mare (GPT-4 până la 7B), distilarea directă este ineficientă. Utilizați distilare în cascadă: GPT-4 -> 13B -> 7B -> 3B. Fiecare pas nu trebuie să depășească reducerea de 4-5 ori.
- Ignorarea calității setului de date de distilare: calitatea de setul de date pe care efectuați distilarea are un impact uriaș. Folosiți date diverse și reprezentant al sarcinii țintă. Datele din afara distribuției corupă transferul.
- Alfa prost calibrat: cu alpha=1 (numai KD) elevul ignoră etichetează adevărul de bază și poate genera predicții instabile dacă profesorul greșește. Cu alpha=0 se reduce la antrenamentul standard. Valorile 0,5-0,8 sunt de obicei optime.
- Nu înghețați profesorul: profesorul trebuie să fie în modul eval(). în timpul pregătirii elevului. Dacă profesorul continuă să se schimbe (de exemplu, are BatchNorm în modul tren), țintele de distilare sunt inconsecvente și antrenamentul diverge.
Distilarea pentru LLM: decodare speculativă și distilare cu răspuns
În contextul modelelor de limbaj mari, distilarea ia forme noi și puternice. Două tehnici care sunt deosebit de relevante în 2025-2026 sunt Distilarea răspunsului (folosit pentru a antrena modelele Llama-3.2 și Microsoft Phi) și iată Decodare speculativă, care folosește un model mic pentru a accelera inferența a modelului mare fără pierderi de calitate.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principio: un modello piccolo (draft) genera K token in anticipo.
# Il modello grande (target) verifica tutti i K token in un solo forward pass.
# Se il draft ha ragione, si risparmiano K-1 forward pass del modello grande.
# Speedup tipico: 2-4x senza perdita di qualità.
class SpeculativeDecoder:
"""
Implementazione base di speculative decoding.
Draft model: modello piccolo (es. Llama-3.2-1B)
Target model: modello grande (es. Llama-3.1-8B)
Riferimento: "Fast Inference from Transformers via Speculative Decoding"
(Leviathan et al., 2022) - il paper originale di Google.
"""
def __init__(
self,
draft_model_name: str,
target_model_name: str,
device: str = "cuda",
lookahead_k: int = 5 # Token generati dal draft per ogni step
):
self.device = device
self.lookahead_k = lookahead_k
print(f"Carico draft model: {draft_model_name}...")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16
).to(device).eval()
print(f"Carico target model: {target_model_name}...")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16
).to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def draft_generate(self, input_ids: torch.Tensor) -> tuple:
"""
Il draft model genera K token e restituisce
le distribuzioni di probabilità per acceptance/rejection.
"""
draft_ids = input_ids.clone()
draft_probs = []
with torch.no_grad():
for _ in range(self.lookahead_k):
out = self.draft_model(draft_ids)
next_logits = out.logits[:, -1, :]
next_probs = F.softmax(next_logits, dim=-1)
# Campiona dal draft
next_token = torch.multinomial(next_probs, num_samples=1)
draft_probs.append(next_probs)
draft_ids = torch.cat([draft_ids, next_token], dim=1)
# draft_ids ora include K token aggiuntivi
draft_tokens = draft_ids[:, input_ids.shape[1]:]
return draft_tokens, draft_probs
def speculative_generate(
self,
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.7
) -> str:
"""Genera testo con speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated = input_ids.clone()
total_accepted = 0
total_draft = 0
with torch.no_grad():
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. Draft genera K token
draft_tokens, draft_probs = self.draft_generate(generated)
total_draft += self.lookahead_k
# 2. Target verifica tutti i K+1 token in un forward pass
full_seq = torch.cat([generated, draft_tokens], dim=1)
target_out = self.target_model(full_seq)
target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
# 3. Acceptance-rejection sampling
n_accepted = 0
for i in range(draft_tokens.shape[1]):
draft_tok = draft_tokens[0, i].item()
target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
draft_p = draft_probs[i][0, draft_tok].item()
target_p = target_probs[draft_tok].item()
# Accetta se target d'accordo con draft
r = torch.rand(1).item()
if r < min(1.0, target_p / (draft_p + 1e-10)):
n_accepted += 1
else:
# Rifiuta: campiona dal target corretto
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]),
num_samples=1
)
generated = torch.cat([
generated,
draft_tokens[:, :i],
corrected.unsqueeze(0)
], dim=1)
break
else:
# Tutti accettati: aggiungi bonus token dal target
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (atteso 60-80% con draft simile)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION per LLM (semplificata)
# ============================================================
# Tecnica usata da Llama-3.2, Phi-3, Mistral-7B-Instruct:
# 1. Teacher LLM grande (es. GPT-4, Llama-3.1-70B) genera risposte
# 2. Student LLM piccolo impara a imitare quelle risposte
# Diverso dalla distillazione classica: distilla risposte (output testo),
# non distribuzioni di probabilità interne.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Genera dataset di distillazione con risposte del teacher.
In produzione: usa GPT-4 API, Llama-3.1-70B, o Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generato: {len(dataset)} esempi -> {output_file}")
return dataset
# Note: in pratica, usa l'API di un servizio commerciale (OpenAI, Anthropic)
# per generare le risposte del "teacher", poi addestra lo student su di esse.
# Questa e la tecnica dietro la maggior parte dei modelli instruction-following
# come Alpaca, Vicuna, Orca, e i modelli Phi di Microsoft.
print("LLM distillation patterns pronti")
Distilare: comparație tehnică între variante (2024-2025)
| Variantă | Ceea ce distilează | Pro | Împotriva | Utilizare tipică |
|---|---|---|---|---|
| Etichete soft (Hinton 2015) | Distribuții de probabilitate | Informații bogate, standard | Necesită acces la logit-urile profesorului | Viziune, clasificare |
| Caracteristică Distilarea | Reprezentări intermediare | Transfer profund de caracteristici | Profesorul și studentul trebuie să aibă o arhitectură compatibilă | Detectare, segmentare |
| Distilarea răspunsului | Rezultatul textual al profesorului | Nu necesită acces intern | Pierde informații despre incertitudine | Urmărirea instrucțiunilor LLM |
| Rețele Born-Again | Auto-distilarea iterativă | Nu necesită profesor separat | Câștig limitat, cost de calcul ridicat | Ansamblu, îmbunătățire |
| Decodare speculativă | Nu este distilare, dar folosește draft | Accelerare de 2-4 ori fără pierderi | Necesită două modele în memorie | Accelerarea inferenței LLM |
Concluzii
Distilarea prin cunoștințe este una dintre cele mai puternice și versatile tehnici de compresie disponibil în 2026. Combinați în mod natural cu cuantizarea și tăierea: mai întâi distilați pentru a crea studentul optim, apoi cuantificați studentul pentru implementarea marginii. Rezultatul și adesea un model de 10-100 de ori mai mic decât profesorul cu doar 5-15% pierderi de precizie.
Pentru LLM, distilarea a permis întreaga familie de modele „Distil*”: DistilBERT, DistilGPT2 și modelele Microsoft Phi (2.7B cu calitatea modelului 7B). Tendința anului 2026 — Modele lingvistice mici (SLM) care depășesc LLM-urile ca frecvență de utilizare conform Gartner — și făcut posibil tocmai prin distilare, care transferă cunoștințele despre giganți la modele care rulează pe Raspberry Pi și smartphone-uri.
Următorul articol arată cum să implementați aceste modele comprimate pe dispozitive de margine: Raspberry Pi, NVIDIA Jetson și hardware încorporat, cu toate optimizările necesare pentru mediul real de producție.
Următorii pași
- Articolul următor: Învățare profundă pe dispozitivele Edge: de la cloud la Edge
- Înrudit: Cuantificare model: GPTQ, AWQ, INT8
- Înrudit: Tunderea rețelelor neuronale: reducerea parametrilor
- Înrudit: Vision Transformer: distilare cu DeiT
- Seria MLOps: Servirea modelelor comprimate în producție







