Destylacja wiedzy: kompresja złożonych modeli
GPT-4 jest zbyt duży, aby uruchomić go na twoim laptopie. ResNet-152 jest dla Ciebie za wolny urządzenie mobilne. Jednak wiedza zdobyta z tych gigantycznych modeli poprzez tygodnie szkoleń na temat klastrów GPU przeniesiony do modeli znacznie mniejszy i szybszy, z zaskakująco niewielką utratą celności. To jest obietnica Destylacja wiedzy.
Destylacja zaproponowana przez Hintona, Vinyalsa i Deana w 2015 roku stała się jedną z technik najpotężniejszy i najbardziej wszechstronny w nowoczesnym zestawie narzędzi do głębokiego uczenia się. Zasada jest elegancka: duży model (nauczyciel) kieruje szkoleniem mniejszego modelu (student) nie tylko z twardymi etykietami (0/1), ale z miękkie prawdopodobieństwa nauczyciela — rozkłady prawdopodobieństwa bogate w informacje o podobieństwach między klasami. Ta ukryta informacja – „ciemna wiedza” – jest tym, co powoduje destylację tak skuteczny.
W tym przewodniku szczegółowo omawiamy destylację: od oryginalnej teorii po odmiany nowoczesny (cecha destylacji, przeniesienie uwagi, samodestylacja, destylacja LLM), z Kompletne wdrożenia w PyTorch, najlepsze praktyki produkcyjne i studium przypadku ze świata rzeczywistego.
Czego się nauczysz
- Teoria destylacji: miękkie etykiety, temperatura i wiedza ciemna
- Wdrażanie standardowej destylacji za pomocą PyTorch
- Funkcja Destylacja: Przenoszenie reprezentacji pośrednich
- Przeniesienie uwagi: Destylacja map uwagi Transformersów
- Sieci samodestylacji i Born Again
- Offline vs online vs samodestylacja
- Destylacja dla LLM: od dużych modeli po modele brzegowe
- Połącz destylację z kwantyzacją i przycinaniem
- Najlepsze praktyki, typowe błędy i wskaźniki oceny
- Studium przypadku: DistilBERT krok po kroku
Zasada destylacji: ciemna wiedza
Standardowy klasyfikator przeszkolony na twardych etykietach (one-hot) wykorzystuje bardzo mało informacji: właściwa klasa ma prawdopodobieństwo 1, wszystkie pozostałe 0. Ale dobrze wytrenowany model wie dużo więcej: wie, że kot bardziej przypomina psa niż samochód. Ta informacja jest zawarte w rozkładzie prawdopodobieństwa nauczyciela, nawet jeśli odpowiedź jest prawidłowa ma prawdopodobieństwo bliskie 1.
Sztuka destylacji polega na użyciu jednego temperatura T „zmiękczyć” prawdopodobieństwa nauczyciela, wzmacniając różnice między mało prawdopodobnymi, ale pouczającymi zajęciami. Przy T=1 mamy rozkład pierwotny; przy wysokim T (np. T = 4) prawdopodobieństwa stają się większe jednolite, ujawniając ukryte relacje podobieństwa. Mechanizm ten nazywa się ciemna wiedza — wiedza ukryta w logitach modelu, że a prosta etykieta binarna nie może zostać przechwycona.
# Visualizzazione effetto della temperatura sulla dark knowledge
import torch
import torch.nn.functional as F
import numpy as np
# Supponiamo che il teacher produca questi logits per un campione
# con classe vera = 0 (gatto)
teacher_logits = torch.tensor([8.2, 2.1, 1.8, 0.5, 0.3, -0.2, -0.5, -0.8, -1.1, -1.5])
# Classi ipotetiche: [gatto, cane, felino, leone, volpe, auto, aereo, nave, treno, barca]
classi = ["gatto", "cane", "felino", "leone", "volpe", "auto", "aereo", "nave", "treno", "barca"]
print("Effetto della temperatura sulle soft probabilities:")
print("-" * 75)
for T in [1, 2, 4, 8, 20]:
probs = F.softmax(teacher_logits / T, dim=0)
entropia = -(probs * probs.log()).sum().item()
print(f"T={T:2d}: p(gatto)={probs[0]:.4f}, "
f"p(cane)={probs[1]:.4f}, p(felino)={probs[2]:.4f}, "
f"entropia={entropia:.3f}")
# Output:
# T= 1: p(gatto)=0.9833, p(cane)=0.0106, p(felino)=0.0079, entropia=0.123
# T= 2: p(gatto)=0.9175, p(cane)=0.0440, p(felino)=0.0297, entropia=0.424
# T= 4: p(gatto)=0.7562, p(cane)=0.1168, p(felino)=0.0913, entropia=0.895
# T= 8: p(gatto)=0.5756, p(cane)=0.1572, p(felino)=0.1343, entropia=1.387
# T=20: p(gatto)=0.3520, p(cane)=0.1668, p(felino)=0.1600, entropia=1.944
#
# Con T alta, il teacher rivela che "gatto" e molto simile a "cane" e "felino"
# ma niente a che fare con "auto" o "aereo".
# Questa e la DARK KNOWLEDGE che lo student impara!
print("\nAnalisi della dark knowledge:")
probs_t1 = F.softmax(teacher_logits / 1, dim=0)
probs_t4 = F.softmax(teacher_logits / 4, dim=0)
for i, classe in enumerate(classi):
print(f" {classe:8s}: T=1: {probs_t1[i]:.4f}, T=4: {probs_t4[i]:.4f}")
Matematyka destylacji
Strata destylacyjna łączy w sobie dwa terminy z hiperparametrem równoważącym alfa:
- L_destylacja: Rozbieżność KL pomiędzy miękkimi prawdopodobieństwami ucznia i nauczyciela (w temperaturze T), pomnożona przez T² w celu skompensowania zmniejszenia wielkości gradientu
- L_uczeń: Standardowa entropia krzyżowa między przewidywaniami uczniów i twardymi etykietami
Ostateczna formuła to: L = alfa * T² * KL(student_soft || nauczyciel_soft) + (1-alfa) * CE(student, etykiety)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
# ============================================================
# DISTILLATION LOSS - Implementazione completa
# ============================================================
class DistillationLoss(nn.Module):
"""
Loss per Knowledge Distillation (Hinton et al., 2015).
L = alpha * T^2 * KL(student_soft || teacher_soft) + (1-alpha) * CE(student, labels)
Il fattore T^2 e fondamentale: quando si usa T > 1,
i gradienti si riducono di 1/T^2. Moltiplicando per T^2
si compensa questo effetto, mantenendo scale dei gradienti
coerenti tra la KD loss e la CE loss.
"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
"""
temperature: scala le probabilità soft (tipico: 2-8)
alpha: peso della distillation loss (tipico: 0.5-0.9)
"""
super().__init__()
self.T = temperature
self.alpha = alpha
self.ce_loss = nn.CrossEntropyLoss()
self.kl_loss = nn.KLDivLoss(reduction='batchmean')
def forward(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor) -> dict:
# Soft probabilities a temperatura T
# NB: log_softmax per student (richiesto da KLDivLoss)
student_soft = F.log_softmax(student_logits / self.T, dim=1)
teacher_soft = F.softmax(teacher_logits / self.T, dim=1)
# KL divergence * T^2 per compensare la riduzione del gradiente
loss_distill = self.kl_loss(student_soft, teacher_soft) * (self.T ** 2)
# Cross-entropy standard con hard labels
loss_student = self.ce_loss(student_logits, labels)
# Combinazione pesata
total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_student
return {
'total': total_loss,
'distill': loss_distill.detach(),
'student': loss_student.detach()
}
# ============================================================
# MODELLI TEACHER e STUDENT
# ============================================================
def create_teacher_student(n_classes: int = 100):
"""
Teacher: ResNet-50 (~25M parametri) - pre-trained su ImageNet
Student: MobileNetV3-Small (~2.5M parametri) - 10x più piccolo
"""
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, n_classes)
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, n_classes
)
total_teacher = sum(p.numel() for p in teacher.parameters())
total_student = sum(p.numel() for p in student.parameters())
flops_teacher = 4.1e9 # Approx FLOPs ResNet-50
flops_student = 0.056e9 # Approx FLOPs MobileNetV3-S
print(f"Teacher (ResNet-50): {total_teacher:,} param, {flops_teacher/1e9:.1f}G FLOPs")
print(f"Student (MobileNetV3): {total_student:,} param, {flops_student*1000:.0f}M FLOPs")
print(f"Fattore compressione: {total_teacher/total_student:.1f}x param, "
f"{flops_teacher/flops_student:.0f}x FLOPs")
return teacher, student
# ============================================================
# TRAINING LOOP CON DISTILLAZIONE
# ============================================================
def train_with_distillation(
teacher: nn.Module,
student: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
n_epochs: int = 50,
temperature: float = 4.0,
alpha: float = 0.7,
lr: float = 1e-3,
device: str = "cuda"
):
teacher = teacher.to(device).eval() # Teacher: SOLO inference, no backprop!
student = student.to(device)
criterion = DistillationLoss(temperature=temperature, alpha=alpha)
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
best_acc = 0.0
history = {'train_loss': [], 'val_acc': [], 'distill_loss': [], 'student_loss': []}
for epoch in range(n_epochs):
student.train()
total_loss = distill_loss_sum = student_loss_sum = 0.0
n_batches = 0
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
# Forward pass teacher (CRITICO: no gradients, risparmia memoria!)
with torch.no_grad():
teacher_logits = teacher(imgs)
# Forward pass student (con gradients)
student_logits = student(imgs)
# Loss combinata
losses = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
losses['total'].backward()
torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
optimizer.step()
total_loss += losses['total'].item()
distill_loss_sum += losses['distill'].item()
student_loss_sum += losses['student'].item()
n_batches += 1
scheduler.step()
# Validation
student.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
preds = student(imgs).argmax(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_total = total_loss / n_batches
history['train_loss'].append(avg_total)
history['val_acc'].append(val_acc)
history['distill_loss'].append(distill_loss_sum / n_batches)
history['student_loss'].append(student_loss_sum / n_batches)
if val_acc > best_acc:
best_acc = val_acc
torch.save(student.state_dict(), "best_student.pth")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {avg_total:.4f} "
f"| Val Acc: {val_acc:.4f} | Best: {best_acc:.4f}")
print(f"\nMiglior accuracy student: {best_acc:.4f}")
return history, best_acc
# Risultati tipici CIFAR-100:
# ResNet-50 teacher: 78.2% Top-1
# MobileNetV3-S senza KD: 67.1% Top-1
# MobileNetV3-S con KD: 71.4% Top-1 (+4.3%)
# MobileNetV3-S con KD+feat: 73.2% Top-1 (+6.1%)
# Compression: 10x param, 73x FLOPs
Destylacja cech: przenoszenie reprezentacji wewnętrznych
Destylacja na miękkich etykietach przekazuje jedynie końcowy dorobek nauczyciela. Tam Funkcja Destylacja idzie dalej: zmusza ucznia do również replikowania reprezentacje pośrednie nauczyciela — mapy cech na różnych poziomach sieci. Jest to szczególnie skuteczne, gdy nauczyciele i uczniowie mają bardzo różne architektury (np. nauczyciel CNN, student ViT) oraz gdy zadanie wymaga bogatych funkcji przestrzennych.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ============================================================
# FEATURE EXTRACTOR via Forward Hooks
# ============================================================
class FeatureExtractor:
"""Cattura le feature di layer specifici tramite forward hooks."""
def __init__(self, model: nn.Module, layer_names: list):
self.features = {}
self.hooks = []
for name, module in model.named_modules():
if name in layer_names:
hook = module.register_forward_hook(
lambda m, inp, out, n=name: self.features.update({n: out})
)
self.hooks.append(hook)
def get_features(self) -> list:
return list(self.features.values())
def clear(self):
self.features.clear()
def remove(self):
"""Rimuovi hooks per evitare memory leaks."""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
# ============================================================
# FEATURE DISTILLATION LOSS
# ============================================================
class FeatureDistillationLoss(nn.Module):
"""
Loss che combina:
1. KD loss standard (soft labels output)
2. Feature Matching Loss (MSE tra feature intermedie normalizzate)
3. Relation-Based Loss (distanze relative tra sample nel batch)
"""
def __init__(self, student_channels: list, teacher_channels: list,
temperature: float = 4.0, alpha: float = 0.4,
beta: float = 0.4, gamma: float = 0.2):
"""
alpha: peso KD loss
beta: peso feature matching loss
gamma: peso CE loss standard
(alpha + beta + gamma deve essere 1.0)
"""
super().__init__()
assert abs(alpha + beta + gamma - 1.0) < 1e-6, "Pesi devono sommare a 1"
self.T = temperature
self.alpha = alpha
self.beta = beta
self.gamma = gamma
# Adattatori per allineare dimensioni teacher->student
# Esempio: teacher ha 2048 canali, student 96 -> adapter 1x1 conv
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(t_ch, s_ch, 1, bias=False),
nn.BatchNorm2d(s_ch),
nn.ReLU(inplace=True)
)
for t_ch, s_ch in zip(teacher_channels, student_channels)
])
def forward(self, student_logits, teacher_logits, labels,
student_features: list, teacher_features: list):
# 1. KD Loss (soft labels)
kl = nn.KLDivLoss(reduction='batchmean')
loss_kd = kl(
F.log_softmax(student_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1)
) * self.T ** 2
# 2. CE Loss standard
loss_ce = F.cross_entropy(student_logits, labels)
# 3. Feature Matching Loss
loss_feat = torch.tensor(0.0, device=student_logits.device)
for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
# Adatta canali del teacher allo student
t_adapted = self.adapters[i](t_feat.detach())
# Allinea risoluzione spaziale se necessario
if s_feat.shape[2:] != t_adapted.shape[2:]:
t_adapted = F.interpolate(
t_adapted, size=s_feat.shape[2:],
mode='bilinear', align_corners=False
)
# Normalizza le feature (cosine similarity invece di MSE)
s_norm = F.normalize(s_feat.view(s_feat.size(0), -1), dim=1)
t_norm = F.normalize(t_adapted.view(t_adapted.size(0), -1), dim=1)
# MSE tra feature normalizzate
loss_feat = loss_feat + F.mse_loss(s_norm, t_norm)
loss_feat = loss_feat / max(len(student_features), 1)
total = self.alpha * loss_kd + self.beta * loss_feat + self.gamma * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'feat': loss_feat.detach()
}
# Configurazione per ResNet-50 teacher -> MobileNetV3-S student
# Layer teacher: [layer2, layer3, layer4] -> Canali: [512, 1024, 2048]
# Layer student: [features.4, features.9, features.12] -> Canali: [40, 96, 576]
teacher = models.resnet50(pretrained=True)
student = models.mobilenet_v3_small(pretrained=False)
teacher_layers = ['layer2', 'layer3', 'layer4']
student_layers = ['features.4', 'features.9', 'features.12']
teacher_channels = [512, 1024, 2048]
student_channels = [40, 96, 576]
teacher_extractor = FeatureExtractor(teacher, teacher_layers)
student_extractor = FeatureExtractor(student, student_layers)
feat_criterion = FeatureDistillationLoss(
student_channels=student_channels,
teacher_channels=teacher_channels,
temperature=4.0, alpha=0.4, beta=0.4, gamma=0.2
)
print("Feature Distillation setup completato!")
print(f"Teacher layers: {teacher_layers}")
print(f"Student layers: {student_layers}")
Przeniesienie uwagi na transformator i transformator wizyjny
Transformatory wizyjne tworzą wyraźne mapy uwagi, które można bezpośrednio analizować destylowana. DeiT (Data-efektywny Image Transformer) wykorzystuje to podejście ze specjalnym żetonem destylacji. THE'Przeniesienie uwagi (Zagoruyko & Komodakis, 2017) rozszerza tę koncepcję także na CNN, budując mapy uwagi z aktywacji warstw splotowych.
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ============================================================
# ATTENTION TRANSFER (AT) per CNN
# ============================================================
class AttentionTransferLoss(nn.Module):
"""
Attention Transfer (Zagoruyko & Komodakis, 2017).
Forza lo student a replicare le attention maps del teacher.
Efficace per transfer tra architetture diverse (CNN <-> ViT).
"""
def __init__(self, beta: float = 1000.0):
super().__init__()
self.beta = beta
def attention_map(self, features: torch.Tensor) -> torch.Tensor:
"""
Calcola mappa di attention come norma L2 quadrata delle attivazioni.
Input features: [B, C, H, W]
Output: [B, H*W] normalizzato (attention map piatta)
"""
# Somma sui canali -> [B, H, W]
attention = features.pow(2).sum(dim=1)
# Appiattisci -> [B, H*W]
attention = attention.view(attention.size(0), -1)
# Normalizza L2 per ogni sample nel batch
return F.normalize(attention, p=2, dim=1)
def forward(self, student_features: list, teacher_features: list) -> torch.Tensor:
"""Calcola AT loss su più livelli."""
total_loss = torch.tensor(0.0)
for s_feat, t_feat in zip(student_features, teacher_features):
s_attn = self.attention_map(s_feat)
t_attn = self.attention_map(t_feat).detach()
# Allinea dimensioni spaziali se necessario
if s_attn.shape != t_attn.shape:
s_h = int(s_feat.shape[2] * s_feat.shape[3])
t_h = int(t_feat.shape[2] * t_feat.shape[3])
# Usa interpolazione sull'attention map 2D
s_2d = s_feat.pow(2).mean(1, keepdim=True)
t_2d = t_feat.pow(2).mean(1, keepdim=True)
t_2d = F.interpolate(t_2d, size=s_feat.shape[2:], mode='bilinear')
s_attn = F.normalize(s_2d.view(s_2d.size(0), -1), p=2, dim=1)
t_attn = F.normalize(t_2d.view(t_2d.size(0), -1), p=2, dim=1).detach()
total_loss = total_loss + (s_attn - t_attn).pow(2).mean()
return self.beta * total_loss / max(len(student_features), 1)
# ============================================================
# DeiT-STYLE: Distillation Token per Vision Transformer
# ============================================================
class ViTWithDistillationToken(nn.Module):
"""
Aggiunge un distillation token a un ViT standard.
Come in DeiT: il token impara a replicare le predizioni
di un teacher CNN (es. RegNet, ResNet).
Durante inference: media CLS token + dist token.
Durante training: loss su entrambi i token.
"""
def __init__(self, vit_model: nn.Module, n_classes: int, d_model: int = 384):
super().__init__()
self.vit = vit_model
# Token di distillazione learnable
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.dist_token, std=0.02)
# Distillation head (separato dal CLS head)
self.dist_head = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor, return_dist: bool = False):
# Ottieni feature dal ViT
features = self.vit.forward_features(x)
# CLS prediction (predizione principale)
cls_pred = self.vit.head(features[:, 0])
# Dist token prediction (guida del teacher)
dist_pred = self.dist_head(features[:, 1]) # Assumendo dist_token al pos 1
if self.training:
return cls_pred, dist_pred
else:
# Inference: media delle due predizioni
return (cls_pred + dist_pred) / 2.0
def deit_distillation_loss(cls_pred, dist_pred, teacher_pred, labels,
alpha: float = 0.5, temperature: float = 3.0):
"""
Loss DeiT: combina CE hard labels + KD dal teacher CNN.
"""
# Hard label loss sul CLS token
loss_cls = F.cross_entropy(cls_pred, labels)
# Soft label loss sul distillation token
loss_dist = F.kl_div(
F.log_softmax(dist_pred / temperature, dim=1),
F.softmax(teacher_pred / temperature, dim=1),
reduction='batchmean'
) * temperature ** 2
return alpha * loss_dist + (1 - alpha) * loss_cls
Destylacja dla LLM: od dużych do małych modeli
Destylacja w przypadku LLM opiera się na tych samych zasadach, ale ma pewne ważne cechy szczególne. Słownictwo jest ogromne (32–128 tys. tokenów), modele nauczycieli i uczniów muszą takie być kompatybilny na poziomie tokenizera, a strata działa na poziomie każdego tokena w sekwencji. Przykładami udanych rozwiązań są DistilBERT, DistilGPT2 i rodzina Phi firmy Microsoft.
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
AutoModelForSequenceClassification
)
import torch
import torch.nn.functional as F
# ============================================================
# DISTILLAZIONE LLM: causal language modeling
# ============================================================
def distill_llm_batch(
teacher_model,
student_model,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.7
) -> dict:
"""
Distillazione LLM per next-token prediction.
Funziona per GPT-style (causal) e BERT-style (masked).
teacher_model: modello grande (es. Llama-3-8B)
student_model: modello piccolo (es. Llama-3-1B)
alpha: peso KD loss (1-alpha = peso CE loss standard)
"""
device = next(student_model.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Teacher inference (no gradients, può essere su device diverso)
with torch.no_grad():
teacher_outputs = teacher_model(
input_ids, attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # [B, seq_len, vocab_size]
# Student inference (con gradients)
student_outputs = student_model(
input_ids, attention_mask=attention_mask
)
student_logits = student_outputs.logits
# Shift per next-token prediction (esclude l'ultimo token come input)
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()
# Ridimensiona per calcolo per-token loss
B, S, V = shift_student.shape
shift_student_flat = shift_student.view(B * S, V)
shift_teacher_flat = shift_teacher.view(B * S, V)
shift_labels_flat = shift_labels.view(B * S)
# 1. KD Loss: KL divergence per ogni token
student_log_probs = F.log_softmax(shift_student_flat / temperature, dim=-1)
teacher_probs = F.softmax(shift_teacher_flat / temperature, dim=-1)
loss_kd = F.kl_div(student_log_probs, teacher_probs,
reduction='batchmean') * temperature ** 2
# 2. CE Loss standard (con label -100 per token da ignorare)
loss_ce = F.cross_entropy(
shift_student_flat, shift_labels_flat,
ignore_index=-100 # Padding tokens
)
total = alpha * loss_kd + (1 - alpha) * loss_ce
return {
'total': total,
'kd': loss_kd.detach(),
'ce': loss_ce.detach(),
'perplexity': torch.exp(loss_ce).detach()
}
# ============================================================
# PIPELINE DISTILLAZIONE LLM COMPLETA
# ============================================================
def setup_llm_distillation(
teacher_name: str = "meta-llama/Llama-3.1-8B",
student_name: str = "meta-llama/Llama-3.2-1B",
device: str = "cuda"
):
"""
Setup per distillare un LLM grande in uno più piccolo.
IMPORTANTE: teacher e student devono condividere il tokenizer
per avere distribuzioni compatibili sullo stesso vocabolario.
"""
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Teacher: carica in FP16 per risparmiare memoria
teacher = AutoModelForCausalLM.from_pretrained(
teacher_name,
torch_dtype=torch.float16,
device_map="auto" # Distribuisce su più GPU se disponibili
)
teacher.eval()
# Student: carica in FP32 per training stabile
student = AutoModelForCausalLM.from_pretrained(
student_name,
torch_dtype=torch.float32
).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params/1e9:.1f}B parametri")
print(f"Student: {student_params/1e9:.1f}B parametri")
print(f"Compressione: {teacher_params/student_params:.1f}x")
return teacher, student, tokenizer
print("Setup distillazione LLM completato!")
Sieci samodestylacji i Born Again
La samodestylacja i zaskakujący wariant: modelka pełni rolę nauczyciela do siebie. W Sieci narodzone na nowo (BANs, Furlanello i in. 2018), szkolą kolejne generacje modeli o tej samej architekturze: każda generacja korzysta z poprzednio jako nauczyciel. Rezultatem jest systematyczna poprawa (+1-2% Top-1) bez zwiększyć rozmiar modelu.
import torch
import torch.nn as nn
import copy
# ============================================================
# BORN AGAIN NETWORKS (BANs)
# ============================================================
def born_again_training(model_factory, train_loader, val_loader,
n_generations: int = 3,
temperature: float = 4.0,
n_epochs: int = 30,
device: str = "cuda"):
"""
Allena N generazioni con la stessa architettura.
Gen 1: training standard con CE loss.
Gen 2+: distillazione dalla generazione precedente.
Risultati tipici CIFAR-100:
Gen 1: 67.1% (baseline)
Gen 2: 70.8% (+3.7%)
Gen 3: 72.1% (+5.0%)
Ensemble 1+2+3: 74.8% (+7.7%)
"""
criterion_kd = DistillationLoss(temperature=temperature, alpha=0.7)
criterion_ce = nn.CrossEntropyLoss()
all_models = []
results = []
# === Generazione 1: training standard ===
print("Gen 1: training standard...")
gen1 = model_factory().to(device)
opt1 = torch.optim.SGD(gen1.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch1 = torch.optim.lr_scheduler.MultiStepLR(opt1, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
gen1.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
opt1.zero_grad()
criterion_ce(gen1(imgs), labels).backward()
opt1.step()
sch1.step()
acc1 = _evaluate(gen1, val_loader, device)
results.append(acc1)
all_models.append(gen1)
print(f" Gen 1 val acc: {acc1:.4f}")
teacher = copy.deepcopy(gen1).eval()
# === Generazioni successive con distillazione ===
for gen_idx in range(2, n_generations + 1):
print(f"Gen {gen_idx}: KD da gen {gen_idx-1}...")
student = model_factory().to(device)
opt = torch.optim.SGD(student.parameters(), lr=0.1,
momentum=0.9, weight_decay=5e-4)
sch = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)
for epoch in range(n_epochs):
student.train()
for imgs, labels in train_loader:
imgs, labels = imgs.to(device), labels.to(device)
with torch.no_grad():
t_logits = teacher(imgs)
s_logits = student(imgs)
losses = criterion_kd(s_logits, t_logits, labels)
opt.zero_grad()
losses['total'].backward()
opt.step()
sch.step()
acc = _evaluate(student, val_loader, device)
results.append(acc)
all_models.append(student)
print(f" Gen {gen_idx} val acc: {acc:.4f}")
teacher = copy.deepcopy(student).eval()
# Ensemble di tutti i modelli (upper bound)
ensemble_acc = _ensemble_evaluate(all_models, val_loader, device)
print(f"\nEnsemble {n_generations} gen: {ensemble_acc:.4f}")
return results, all_models
def _evaluate(model, loader, device):
model.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
correct += (model(imgs).argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
def _ensemble_evaluate(models, loader, device):
"""Ensemble averaging delle predizioni."""
for m in models:
m.eval()
correct = total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs, labels = imgs.to(device), labels.to(device)
logits_sum = torch.stack([m(imgs) for m in models]).mean(0)
correct += (logits_sum.argmax(1) == labels).sum().item()
total += labels.size(0)
return correct / total
Typowe wyniki destylacji (benchmark 2024–2025)
| Zadania | Nauczyciel | Student | Bez KD | Z KD | Nauczyciel | Kompresja |
|---|---|---|---|---|---|---|
| CIFAR-100 | ResNet-50 | MobileNetV3-S | 67,1% | 73,2% | 78,2% | 10x param |
| ImageNet | ViT-L/16 | DeiT-S | 79,8% | 83,1% | 87,1% | 5x param |
| KLEJ (NLP) | BERT-duży | DestylBERT | 83,2% | 86,4% | 89,2% | 2x parametr, 2x prędkość |
| SKŁAD (QA) | ROBERTa-L | DestylujRoBERTa | 82,1% | 85,8% | 90,4% | 2x param |
| LLM (zakłopotanie) | Lama 3.1 8B | Lama 3.2 1B | 8,24 PPL | 7,81 PPL | 6.12 PPL | 8x param |
KD zazwyczaj odzyskuje 70–85% luki między uczniem a nauczycielem przy 2–10 razy mniejszej liczbie parametrów.
Rurociąg produkcyjny: destylacja + oznaczanie ilościowe
Najpotężniejszy przepływ pracy przy wdrażaniu brzegowym łączy w sobie destylację i kwantyzację sekwencja: najpierw utwórz ucznia za pomocą KD (zachowując wysoką dokładność), a następnie skwantuj ucznia (zmniejsza rozmiar i zwiększa prędkość). Kombinacja może zmniejszyć model od 100x do 40x w porównaniu do oryginalnego nauczyciela, przy jedynie 5-10% utracie dokładności.
import torch
import torch.nn as nn
from torchvision import models
# ============================================================
# PIPELINE COMPLETA: Distillazione -> Quantizzazione -> ONNX
# ============================================================
def full_compression_pipeline(teacher_path: str, output_dir: str = "./compressed"):
"""
Pipeline completa per comprimere un modello per edge deployment.
Step 1: Carica teacher pre-trainato
Step 2: Distilla in student più piccolo
Step 3: Quantizza lo student (PTQ INT8)
Step 4: Esporta in ONNX per deployment cross-platform
"""
import os
os.makedirs(output_dir, exist_ok=True)
# STEP 1: Teacher
print("Step 1: Carico teacher...")
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10) # 10 classi
# In produzione: teacher.load_state_dict(torch.load(teacher_path))
teacher.eval()
teacher_size_mb = sum(p.numel() * p.element_size()
for p in teacher.parameters()) / (1024**2)
print(f" Teacher: {teacher_size_mb:.1f} MB, "
f"{sum(p.numel() for p in teacher.parameters())/1e6:.1f}M param")
# STEP 2: Student (dopo distillazione)
print("Step 2: Student con distillazione (simulato con MobileNetV3)...")
student = models.mobilenet_v3_small(pretrained=False)
student.classifier[3] = nn.Linear(
student.classifier[3].in_features, 10
)
# In produzione: train_with_distillation(teacher, student, ...)
# student.load_state_dict(torch.load("best_student.pth"))
student_size_mb = sum(p.numel() * p.element_size()
for p in student.parameters()) / (1024**2)
print(f" Student: {student_size_mb:.1f} MB, "
f"{sum(p.numel() for p in student.parameters())/1e6:.1f}M param")
print(f" Riduzione rispetto teacher: {teacher_size_mb/student_size_mb:.1f}x")
# STEP 3: Quantizzazione INT8 (PTQ)
print("Step 3: Quantizzazione INT8...")
student.eval()
# Quantizzazione dinamica (più semplice, leggermente meno efficiente)
student_quant = torch.quantization.quantize_dynamic(
student,
{nn.Linear}, # Quantizza solo Linear (Conv2d richiede calibrazione)
dtype=torch.qint8
)
# Per quantizzazione statica (più efficiente):
# student.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# torch.quantization.prepare(student, inplace=True)
# calibrate(student, calib_loader) # Forward pass su dati di calibrazione
# torch.quantization.convert(student, inplace=True)
quant_size_mb = sum(p.numel() * p.element_size()
for p in student_quant.parameters()) / (1024**2)
print(f" Student INT8: ~{student_size_mb/4:.1f} MB (stima)")
print(f" Riduzione totale: ~{teacher_size_mb/(student_size_mb/4):.0f}x vs teacher")
# STEP 4: Export ONNX
print("Step 4: Export ONNX...")
dummy = torch.randn(1, 3, 224, 224)
onnx_path = f"{output_dir}/student_compressed.onnx"
torch.onnx.export(
student, # Usa FP32 per ONNX (compatibilità più ampia)
dummy,
onnx_path,
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}},
export_params=True
)
print(f"\n=== RIEPILOGO PIPELINE ===")
print(f"Teacher (ResNet-50): {teacher_size_mb:.1f} MB")
print(f"Student KD (MobNetV3): {student_size_mb:.1f} MB ({teacher_size_mb/student_size_mb:.1f}x riduzione)")
print(f"Student INT8 (stimato): {student_size_mb/4:.1f} MB ({teacher_size_mb/(student_size_mb/4):.0f}x riduzione)")
print(f"ONNX salvato: {onnx_path}")
return student_quant, onnx_path
full_compression_pipeline("teacher_weights.pth")
Anty-wzorce w destylacji: częste błędy
- Temperatury zbyt wysokie lub zbyt niskie: T=1 oznacza twarde etykiety. Zbyt wysokie T (>20) powoduje, że miękkie etykiety są prawie jednolite, tracąc sygnał. Zawsze wykonuj badanie ablacji z T ∈ {2, 4, 6, 8} na konkretnym zestawie danych.
- Nauczyciel i uczeń zbyt różnią się zdolnościami: jeśli różnica jest ogromna (GPT-4 do 7B), destylacja bezpośrednia jest nieskuteczna. Użyj destylacji wodospadowej: GPT-4 -> 13B -> 7B -> 3B. Każdy krok nie powinien przekraczać 4-5x redukcji.
- Ignorując jakość zbioru danych dotyczących destylacji: jakość zbiór danych, na którym przeprowadzasz destylację, ma ogromny wpływ. Korzystaj z różnorodnych danych i reprezentatywny dla docelowego zadania. Dane spoza dystrybucji zakłócają transfer.
- Źle skalibrowana alfa: przy alfa=1 (tylko KD) uczeń ignoruje podstawowe etykiety prawdy i może generować niestabilne przewidywania, jeśli nauczyciel popełni błąd. Przy alfa=0 następuje redukcja do standardowego treningu. Wartości 0,5-0,8 są zazwyczaj optymalne.
- Nie zamrażaj nauczyciela: nauczyciel musi być w trybie eval(). podczas szkolenia ucznia. Jeśli nauczyciel ciągle się zmienia (np. ma BatchNorm w trybie pociągu) cele destylacji są niespójne, a szkolenie jest rozbieżne.
Destylacja dla LLM: dekodowanie spekulatywne i destylacja odpowiedzi
W kontekście modeli wielkojęzykowych destylacja przybiera nowe i potężne formy. Dwie techniki, które są szczególnie istotne w latach 2025–2026, to: Destylacja odpowiedzi (używany do trenowania modeli Llama-3.2 i Microsoft Phi) i lo Dekodowanie spekulatywne, który wykorzystuje mały model w celu przyspieszenia wnioskowania dużego modelu bez utraty jakości.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
# ============================================================
# SPECULATIVE DECODING: Draft Model + Target Model
# ============================================================
# Principio: un modello piccolo (draft) genera K token in anticipo.
# Il modello grande (target) verifica tutti i K token in un solo forward pass.
# Se il draft ha ragione, si risparmiano K-1 forward pass del modello grande.
# Speedup tipico: 2-4x senza perdita di qualità.
class SpeculativeDecoder:
"""
Implementazione base di speculative decoding.
Draft model: modello piccolo (es. Llama-3.2-1B)
Target model: modello grande (es. Llama-3.1-8B)
Riferimento: "Fast Inference from Transformers via Speculative Decoding"
(Leviathan et al., 2022) - il paper originale di Google.
"""
def __init__(
self,
draft_model_name: str,
target_model_name: str,
device: str = "cuda",
lookahead_k: int = 5 # Token generati dal draft per ogni step
):
self.device = device
self.lookahead_k = lookahead_k
print(f"Carico draft model: {draft_model_name}...")
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16
).to(device).eval()
print(f"Carico target model: {target_model_name}...")
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16
).to(device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
def draft_generate(self, input_ids: torch.Tensor) -> tuple:
"""
Il draft model genera K token e restituisce
le distribuzioni di probabilità per acceptance/rejection.
"""
draft_ids = input_ids.clone()
draft_probs = []
with torch.no_grad():
for _ in range(self.lookahead_k):
out = self.draft_model(draft_ids)
next_logits = out.logits[:, -1, :]
next_probs = F.softmax(next_logits, dim=-1)
# Campiona dal draft
next_token = torch.multinomial(next_probs, num_samples=1)
draft_probs.append(next_probs)
draft_ids = torch.cat([draft_ids, next_token], dim=1)
# draft_ids ora include K token aggiuntivi
draft_tokens = draft_ids[:, input_ids.shape[1]:]
return draft_tokens, draft_probs
def speculative_generate(
self,
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.7
) -> str:
"""Genera testo con speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
generated = input_ids.clone()
total_accepted = 0
total_draft = 0
with torch.no_grad():
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. Draft genera K token
draft_tokens, draft_probs = self.draft_generate(generated)
total_draft += self.lookahead_k
# 2. Target verifica tutti i K+1 token in un forward pass
full_seq = torch.cat([generated, draft_tokens], dim=1)
target_out = self.target_model(full_seq)
target_logits = target_out.logits[:, generated.shape[1]-1:-1, :]
# 3. Acceptance-rejection sampling
n_accepted = 0
for i in range(draft_tokens.shape[1]):
draft_tok = draft_tokens[0, i].item()
target_probs = F.softmax(target_logits[0, i] / temperature, dim=-1)
draft_p = draft_probs[i][0, draft_tok].item()
target_p = target_probs[draft_tok].item()
# Accetta se target d'accordo con draft
r = torch.rand(1).item()
if r < min(1.0, target_p / (draft_p + 1e-10)):
n_accepted += 1
else:
# Rifiuta: campiona dal target corretto
corrected = torch.multinomial(
F.relu(target_probs - draft_probs[i][0]),
num_samples=1
)
generated = torch.cat([
generated,
draft_tokens[:, :i],
corrected.unsqueeze(0)
], dim=1)
break
else:
# Tutti accettati: aggiungi bonus token dal target
generated = torch.cat([generated, draft_tokens], dim=1)
bonus_logits = target_out.logits[:, -1, :]
bonus_tok = torch.multinomial(
F.softmax(bonus_logits / temperature, dim=-1), 1
)
generated = torch.cat([generated, bonus_tok], dim=1)
total_accepted += n_accepted
acceptance_rate = total_accepted / max(total_draft, 1)
print(f"Acceptance rate: {acceptance_rate:.1%} (atteso 60-80% con draft simile)")
new_tokens = generated[0, input_ids.shape[1]:]
return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# ============================================================
# RESPONSE DISTILLATION per LLM (semplificata)
# ============================================================
# Tecnica usata da Llama-3.2, Phi-3, Mistral-7B-Instruct:
# 1. Teacher LLM grande (es. GPT-4, Llama-3.1-70B) genera risposte
# 2. Student LLM piccolo impara a imitare quelle risposte
# Diverso dalla distillazione classica: distilla risposte (output testo),
# non distribuzioni di probabilità interne.
def response_distillation_dataset(
teacher_model_name: str,
prompts: list,
output_file: str = "distillation_dataset.jsonl"
) -> list:
"""
Genera dataset di distillazione con risposte del teacher.
In produzione: usa GPT-4 API, Llama-3.1-70B, o Claude.
"""
import json
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto"
).eval()
dataset = []
with open(output_file, 'w') as f:
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(teacher.device)
with torch.no_grad():
outputs = teacher.generate(
**inputs,
max_new_tokens=512,
temperature=0.7,
do_sample=True,
top_p=0.9
)
response_ids = outputs[0, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(response_ids, skip_special_tokens=True)
entry = {"prompt": prompt, "response": response}
dataset.append(entry)
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
print(f"Dataset generato: {len(dataset)} esempi -> {output_file}")
return dataset
# Note: in pratica, usa l'API di un servizio commerciale (OpenAI, Anthropic)
# per generare le risposte del "teacher", poi addestra lo student su di esse.
# Questa e la tecnica dietro la maggior parte dei modelli instruction-following
# come Alpaca, Vicuna, Orca, e i modelli Phi di Microsoft.
print("LLM distillation patterns pronti")
Destylacja: porównanie techniczne wariantów (2024–2025)
| Wariant | Co destyluje | Zawodowiec | Przeciwko | Typowe zastosowanie |
|---|---|---|---|---|
| Etykiety miękkie (Hinton 2015) | Rozkłady prawdopodobieństwa | Bogate, standardowe informacje | Wymaga dostępu do logów nauczyciela | Wizja, klasyfikacja |
| Funkcja Destylacja | Reprezentacje pośrednie | Głęboki transfer funkcji | Nauczyciel i uczeń muszą mieć kompatybilną architekturę | Detekcja, segmentacja |
| Destylacja odpowiedzi | Dorobek tekstowy nauczyciela | Nie wymaga dostępu wewnętrznego | Traci informacje o niepewności | Postępowanie zgodnie z instrukcjami LLM |
| Sieci narodzonych na nowo | Iteracyjna samodestylacja | Nie wymaga osobnego nauczyciela | Ograniczony zysk, wysoki koszt obliczeniowy | Zespół, ulepszenie |
| Dekodowanie spekulatywne | To nie jest destylacja, ale wykorzystuje przeciąg | Przyspieszenie 2-4x bez strat | Wymaga dwóch modeli w pamięci | Przyspieszenie wnioskowania LLM |
Wnioski
Destylacja wiedzy jest jedną z najpotężniejszych i najbardziej wszechstronnych technik kompresji dostępny w 2026 r. Połącz naturalnie z kwantyzacją i przycinaniem: najpierw destyluj aby utworzyć optymalnego ucznia, a następnie skwantyfikować ucznia na potrzeby wdrożenia brzegowego. Wynik i często model 10-100x mniejszy od nauczyciela z jedynie 5-15% utratą dokładności.
W przypadku LLM destylacja umożliwiła całą rodzinę modeli „Distil*”: DistilBERT, DistilGPT2 i modele Phi firmy Microsoft (2,7B z jakością modelu 7B). Trend 2026 roku — Modele małego języka (SLM), które przewyższają modele LLM pod względem częstotliwości użycia zgodnie z Gartnera — a stało się to możliwe właśnie dzięki destylacji, która przekazuje wiedzę nt gigantów po modele działające na Raspberry Pi i smartfonach.
W następnym artykule pokazano, jak wdrożyć te skompresowane modele na platformie urządzenia brzegowe: Raspberry Pi, NVIDIA Jetson i sprzęt wbudowany, ze wszystkimi niezbędnymi optymalizacjami dla prawdziwego środowiska produkcyjnego.
Następne kroki
- Następny artykuł: Głębokie uczenie się na urządzeniach brzegowych: od chmury do krawędzi
- Powiązany: Kwantyzacja modelu: GPTQ, AWQ, INT8
- Powiązany: Przycinanie sieci neuronowych: redukcja parametrów
- Powiązany: Vision Transformer: destylacja za pomocą DeiT
- Seria MLOps: Udostępnianie skompresowanych modeli w produkcji







