Pruning delle Reti Neurali: Ridurre la Complessità dei Modelli
Un modello ResNet-50 ha oltre 25 milioni di parametri. GPT-3 ne ha 175 miliardi. Eppure ricerche sistematiche dimostrano che la maggior parte di questi parametri e ridondante: reti neurali addestrate possono perdere oltre il 90% dei propri pesi senza degradazione significativa dell'accuratezza. Il pruning — la tecnica di rimozione sistematica dei parametri superflui — e uno degli strumenti più potenti per ridurre la complessità computazionale dei modelli di deep learning.
A differenza della quantizzazione, che riduce la precisione numerica dei parametri, il pruning li elimina completamente. Il risultato può essere un modello più piccolo, più veloce e meno costoso da eseguire — soprattutto quando si applica il pruning strutturato, che rimuove interi neuroni, filtri o attention heads, producendo speedup reali sull'hardware senza richiedere supporto per la sparsita.
In questa guida esploriamo il pruning a fondo: dalla teoria della Lottery Ticket Hypothesis alle implementazioni pratiche con PyTorch, dal pruning per magnitudine al movement pruning per Transformer, fino ai workflow iterativi e alle combinazioni con la quantizzazione.
Cosa Imparerai
- Differenza tra pruning strutturato e non strutturato, e quando usare ciascuno
- Magnitude pruning: il metodo più semplice ed efficace
- Movement pruning per Transformer e LLM moderni
- Lottery Ticket Hypothesis: la teoria che spiega perchè il pruning funziona
- API di pruning di PyTorch con esempi completi
- Workflow di pruning iterativo con retraining
- Torch-Pruning per pruning strutturato avanzato
- Combinazione pruning + quantizzazione per massima compressione
- Benchmark reali di accuratezza, memoria e velocità
- Best practices e anti-pattern comuni
perchè il Pruning? Il Problema della Ridondanza
Le reti neurali moderne sono notoriamente sovra-parametrizzate. Questa ridondanza e in parte intenzionale: reti più grandi si addestrano più facilmente e generalizzano meglio, ma al momento del deployment si porta dietro peso computazionale inutile. Tre osservazioni empiriche fondamentali motivano il pruning:
- Ridondanza dei pesi: Studi di degenerate pruning dimostrano che in reti addestrate la distribuzione dei pesi e fortemente concentrata attorno allo zero. Rimuovere i pesi di minore magnitudine ha impatto minimo sulle predizioni.
- Lottery Ticket Hypothesis (Frankle & Carlin, 2019): Ogni rete neurale addestrata contiene una sottoreti "vincente" che, se re-inizializzata con i valori originali e addestrata da sola, raggiunge performance comparabili alla rete completa.
- Over-parameterization come strumento: I parametri extra servono per il training (landscape più smooth, escape dai minimi locali), ma non sono necessari per l'inferenza.
Impatto del Pruning: Dati Reali
Ricerche su ResNet e BERT mostrano che modelli possono perdere il 70-90% dei parametri con perdita di accuratezza inferiore all'1-2%. Il pruning strutturato di Transformer su BERT-base con sparsity 50% produce una riduzione di FLOPs del 2x e uno speedup di inferenza di 1.5x mantenendo oltre il 99% dell'accuratezza originale. Nel contesto LLM, le tecniche di block pruning per Transformers hanno dimostrato speedup fino a 2.4x su SQuAD con solo 1% di perdita di F1.
Pruning Strutturato vs Non Strutturato
La distinzione fondamentale nel pruning e tra approcci strutturati e non strutturati. La scelta dipende dall'hardware target e dagli obiettivi di deployment:
| Aspetto | Non Strutturato | Strutturato |
|---|---|---|
| Cosa rimuove | Singoli pesi (arbitrari) | Neuroni, filtri, canali, attention heads, layer |
| Sparsita risultante | Irregolare (matrice sparsa) | Regolare (dimensioni ridotte) |
| Speedup reale su CPU/GPU standard | Nessuno (senza sparse ops) | Si, immediato con dense ops |
| Speedup su hardware sparse (CPU sparse, Cerebras) | Si | Si |
| Riduzione memoria | Solo con formato sparso esplicito | Sempre (dimensioni ridotte) |
| Accuratezza a parita di sparsita | Migliore | Leggermente inferiore |
| Complessità implementazione | Semplice | Più complessa (ricalcolo dipendenze) |
Il pruning non strutturato e più flessibile: può rimuovere qualsiasi peso indipendentemente dalla sua posizione. Il problema e che la matrice risultante rimane densa in memoria (zero espliciti), e l'hardware moderno non trae beneficio dalla sparsita irregolare senza supporto specifico (NVIDIA ha introdotto supporto per sparsita 2:4 con le Ampere GPUs, ma richiede pattern specifici). Il pruning strutturato, rimuovendo strutture complete, produce modelli più piccoli in modo verificabile: un layer Linear con 512 neuroni pruned a 256 diventa semplicemente un Linear(in, 256), eseguito con operazioni dense standard.
Magnitude Pruning: Il Metodo Fondamentale
Il magnitude pruning e l'approccio più semplice e sorprendentemente efficace: rimuovere i pesi con valore assoluto minore di una soglia. La logica intuitiva e che i pesi piccoli contribuiscono poco al segnale trasmesso dalla rete. Nonostante la sua semplicità, quando combinato con retraining iterativo produce risultati competitivi con metodi molto più sofisticati.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
# ===================================================================
# MAGNITUDE PRUNING CON PYTORCH NATIVE API
# ===================================================================
class ConvNet(nn.Module):
"""Modello CNN semplice per dimostrare il pruning."""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(4)
)
self.classifier = nn.Sequential(
nn.Linear(128 * 4 * 4, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
model = ConvNet()
# --- Pruning L1 non strutturato (magnitude-based) ---
# Rimuove il 30% dei pesi con valore assoluto minore
prune.l1_unstructured(
model.features[0], # Layer da pruning
name='weight', # Parametro da prunare
amount=0.30 # Percentuale da rimuovere (30%)
)
# --- Pruning Random (baseline di confronto) ---
prune.random_unstructured(
model.features[2],
name='weight',
amount=0.30
)
# --- Analisi sparsita risultante ---
def compute_sparsity(module):
"""Calcola la sparsita effettiva di un modulo."""
total = 0
zeros = 0
for param in module.parameters():
total += param.numel()
zeros += (param == 0).sum().item()
return zeros / total if total > 0 else 0.0
print("Sparsita Conv1:", f"{compute_sparsity(model.features[0]):.1%}")
print("Sparsita Conv2:", f"{compute_sparsity(model.features[2]):.1%}")
# --- Verifica la struttura interna del pruning ---
# PyTorch crea weight_orig (originale) + weight_mask (0/1)
print("\nParametri di model.features[0] dopo pruning:")
for name, param in model.features[0].named_parameters():
print(f" {name}: shape={param.shape}")
for name, buf in model.features[0].named_buffers():
print(f" buffer {name}: shape={buf.shape}")
# --- Rimozione della maschera (make permanent) ---
# Dopo retraining, si consolida: il modello torna a usare 'weight'
prune.remove(model.features[0], 'weight')
print("\nDopo prune.remove: parametri di model.features[0]:")
for name, _ in model.features[0].named_parameters():
print(f" {name}")
# --- Global Pruning: pruning globale su tutto il modello ---
# Più efficace del pruning per-layer: usa una soglia globale
parameters_to_prune = (
(model.features[0], 'weight'),
(model.features[2], 'weight'),
(model.classifier[0], 'weight'),
(model.classifier[2], 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.40, # Rimuove 40% globalmente (non per layer)
)
# Sparsita finale per layer
for module_name, module in model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
if hasattr(module, 'weight_mask'):
sparsity = (module.weight_mask == 0).float().mean().item()
print(f"{module_name}: sparsita {sparsity:.1%}")
Attenzione: Il Pruning PyTorch Native Non Velocizza l'Inferenza
L'API torch.nn.utils.prune applica una maschera binaria sui pesi, zerando
quelli selezionati ma mantenendo la struttura densa originale. Il modello risultante occupa la
stessa memoria e impiega lo stesso tempo per il forward pass. Per ottenere speedup reali servono:
pruning strutturato (con rimozione fisica delle strutture), o librerie specifiche per sparse
operations. Il pruning PyTorch native e ottimo per sperimentare e per QAT (Quantization-Aware
Training) con sparsita, ma non per deployment diretto.
Pruning Strutturato con Torch-Pruning
La libreria Torch-Pruning (Fang et al., CVPR 2023) risolve il problema del pruning strutturato reale: rimuovere un filtro da un layer Conv2D richiede di aggiornare anche il layer successivo (che si aspetta N canali in input, non N-k). Torch-Pruning gestisce automaticamente queste dipendenze tramite un grafo delle dipendenze (DepGraph), supportando architetture complesse inclusi ViT, LLM, YOLO e modelli con skip connections.
# pip install torch-pruning
import torch
import torch.nn as nn
import torch_pruning as tp
# ===================================================================
# PRUNING STRUTTURATO CON TORCH-PRUNING
# ===================================================================
class ResidualBlock(nn.Module):
"""Blocco residuale: Torch-Pruning gestisce la skip connection automaticamente."""
def __init__(self, channels=64):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
return self.relu(out + residual) # Skip connection
class SimpleResNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.layer1 = ResidualBlock(64)
self.layer2 = ResidualBlock(64)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.pool(x).view(x.size(0), -1)
return self.fc(x)
model = SimpleResNet()
model.eval()
# Input di esempio per tracciare le dipendenze
example_input = torch.randn(1, 3, 32, 32)
# --- Costruzione del grafo delle dipendenze ---
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)
# --- Analisi del modello PRIMA del pruning ---
macs_before, params_before = tp.utils.count_ops_and_params(model, example_input)
print(f"Parametri PRIMA: {params_before / 1e6:.2f}M")
print(f"MACs PRIMA: {macs_before / 1e9:.3f}G")
# --- Definizione della strategia di pruning ---
# Pruning per magnitudine L1 dei filtri (L2 disponibile con tp.strategy.L2Strategy)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_input,
importance=tp.importance.MagnitudeImportance(p=1), # L1 norm
iterative_steps=5, # Pruning iterativo in 5 step
ch_sparsity=0.5, # Rimuove il 50% dei canali
ignored_layers=[model.fc], # Non pruning il classificatore finale
)
# --- Esecuzione del pruning (un singolo step) ---
pruner.step()
# --- Analisi del modello DOPO il pruning ---
macs_after, params_after = tp.utils.count_ops_and_params(model, example_input)
print(f"\nParametri DOPO: {params_after / 1e6:.2f}M")
print(f"MACs DOPO: {macs_after / 1e9:.3f}G")
print(f"Riduzione parametri: {(1 - params_after/params_before):.1%}")
print(f"Riduzione MACs: {(1 - macs_after/macs_before):.1%}")
# --- Verifica architettura post-pruning ---
print("\nArchitettura post-pruning:")
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
print(f" {name}: Conv2d({module.in_channels}, {module.out_channels}, ...)")
# Output tipico:
# Parametri PRIMA: 0.15M | MACs PRIMA: 0.009G
# Parametri DOPO: 0.04M | MACs DOPO: 0.003G
# Riduzione parametri: 75% | Riduzione MACs: 72%
# layer1.conv1: Conv2d(32, 32, ...) <- da 64 a 32 canali
Movement Pruning per Transformer
Il magnitude pruning funziona bene per le CNN, ma i Transformer presentano una sfida diversa: i pesi di attention possono avere magnitudini basse ma essere critici per il comportamento del modello. Il movement pruning (Sanh et al., 2020) affronta questo problema con un approccio radicalmente diverso: invece di rimuovere i pesi piccoli, rimuove quelli che si stanno avvicinando a zero durante il fine-tuning. In altre parole, il criterio e il gradiente del peso rispetto all'obiettivo di pruning, non il valore corrente del peso.
Il movement pruning ha dimostrato vantaggi significativi per il pruning di modelli BERT: a sparsita elevata (80-97%), movement pruning supera il magnitude pruning di 10-20 punti percentuali su benchmark NLP come MNLI e SQuAD.
# Movement Pruning per Transformer con Hugging Face + SparseML
# pip install transformers datasets sparseml
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
# ===================================================================
# MOVEMENT PRUNING MANUALE (concetto base)
# ===================================================================
class MovementPruningLinear(nn.Module):
"""
Layer Linear con movement pruning.
Mantiene uno score per ogni peso: lo score viene ottimizzato
durante il training. I pesi con score basso vengono pruned.
"""
def __init__(self, in_features, out_features, pruning_ratio=0.5):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.bias = nn.Parameter(torch.zeros(out_features))
# Score inizializzati a zero: durante il training salgono per i pesi importanti
self.scores = nn.Parameter(torch.zeros_like(self.weight))
self.pruning_ratio = pruning_ratio
self.mask = None
def update_mask(self):
"""Aggiorna la maschera basandosi sugli score correnti."""
k = int(self.scores.numel() * (1 - self.pruning_ratio))
# Top-k scores: mantieni i pesi con score più alto
threshold = torch.kthvalue(self.scores.flatten(), self.scores.numel() - k).values
self.mask = (self.scores >= threshold).float().detach()
def forward(self, x):
# Applica la maschera durante il forward pass
if self.mask is None:
self.update_mask()
masked_weight = self.weight * self.mask
return nn.functional.linear(x, masked_weight, self.bias)
# ===================================================================
# PRUNING PRATICO CON TRANSFORMERS + torch.nn.utils.prune
# ===================================================================
def prune_transformer_attention_heads(model, heads_to_prune):
"""
Pruna specifici attention heads da un modello BERT-like.
heads_to_prune: dict {layer_idx: [head_idx_1, head_idx_2, ...]}
"""
model.prune_heads(heads_to_prune)
return model
# Esempio: pruning degli attention heads meno importanti
# Identificazione heads da pruning (basata su Taylor importance)
def compute_head_importance(model, dataloader, device):
"""
Calcola l'importanza di ogni attention head usando Taylor expansion.
Un head e importante se rimuoverlo aumenta molto la loss.
"""
model.eval()
head_importance = torch.zeros(
model.config.num_hidden_layers,
model.config.num_attention_heads
).to(device)
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch, output_attentions=True)
loss = outputs.loss
loss.backward()
# Accumula gradienti per stimare l'importanza
for layer_idx, layer in enumerate(model.bert.encoder.layer):
attn = layer.attention.self
# Importanza approssimata: |grad * weight| sommato per head
grad = attn.value.weight.grad
weight = attn.value.weight
if grad is not None:
importance = (grad * weight).abs().view(
model.config.num_attention_heads, -1
).sum(dim=-1)
head_importance[layer_idx] += importance
return head_importance
# ===================================================================
# STRUCTURED PRUNING DI ATTENTION HEADS CON BERT
# ===================================================================
model_name = "bert-base-uncased"
# model = AutoModelForSequenceClassification.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# Strategia di pruning: rimuovi il 30% degli heads meno importanti
# Assumendo head_importance calcolata come sopra:
# heads_to_prune = {}
# n_heads_to_prune = int(0.3 * 12 * 12) # 30% di 144 heads totali (12 layers x 12 heads)
# flat_importance = head_importance.flatten()
# _, indices = flat_importance.sort()
# for idx in indices[:n_heads_to_prune]:
# layer_idx = idx.item() // 12
# head_idx = idx.item() % 12
# if layer_idx not in heads_to_prune:
# heads_to_prune[layer_idx] = []
# heads_to_prune[layer_idx].append(head_idx)
# pruned_model = prune_transformer_attention_heads(model, heads_to_prune)
print("Movement pruning e head importance pruning: schema implementato.")
print("Risultati tipici su BERT-base con 40% pruning attention:")
print(" - Speedup inferenza: 1.3-1.5x")
print(" - Dimensione modello: -35%")
print(" - Accuratezza GLUE: -0.5 a -1.5 punti")
Lottery Ticket Hypothesis: La Teoria del Sottomodello Vincente
La Lottery Ticket Hypothesis (LTH, Frankle & Carlin, NeurIPS 2019) e una delle scoperte teoriche più influenti nel pruning: ogni rete neurale densa contiene una o più sottoreti sparse ("winning tickets") che, se estratte e re-inizializzate con i loro valori iniziali originali, possono essere addestrate da sole raggiungendo accuratezza comparabile o superiore alla rete completa, in tempi di training minori o uguali.
La LTH ha importanti implicazioni pratiche: suggerisce che il grande modello serve principalmente per trovare la struttura giusta, non per le capacità intrinseche dei suoi parametri. Il processo standard per trovare un winning ticket e l'Iterative Magnitude Pruning (IMP).
import torch
import torch.nn as nn
import copy
from typing import Dict, List
# ===================================================================
# ITERATIVE MAGNITUDE PRUNING (Lottery Ticket Hypothesis)
# ===================================================================
def save_initial_weights(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Salva i pesi iniziali del modello (prima del training)."""
return {
name: param.data.clone()
for name, param in model.named_parameters()
if 'weight' in name
}
def apply_mask_and_reinit(
model: nn.Module,
initial_weights: Dict[str, torch.Tensor],
masks: Dict[str, torch.Tensor]
) -> nn.Module:
"""
Reimposta i pesi ai valori iniziali con le maschere di pruning applicate.
Questo e il passo critico della LTH: reinizializzare (non random, ma ai valori originali).
"""
with torch.no_grad():
for name, param in model.named_parameters():
if name in initial_weights and name in masks:
param.data = initial_weights[name] * masks[name]
return model
def compute_pruning_masks(
model: nn.Module,
pruning_ratio: float
) -> Dict[str, torch.Tensor]:
"""Calcola le maschere di pruning per magnitude (L1)."""
masks = {}
for name, param in model.named_parameters():
if 'weight' in name and param.dim() > 1:
# Soglia globale per layer
threshold = torch.quantile(param.abs(), pruning_ratio)
masks[name] = (param.abs() >= threshold).float()
return masks
def iterative_magnitude_pruning(
model: nn.Module,
train_fn,
eval_fn,
n_rounds: int = 5,
prune_per_round: float = 0.20,
epochs_per_round: int = 10
):
"""
Implementazione dell'Iterative Magnitude Pruning (LTH).
Algoritmo:
1. Salva i pesi iniziali (w0)
2. Addestra per N epoche
3. Pruna il P% dei pesi con magnitudine minore
4. Reinizializza i pesi sopravvissuti a w0
5. Ripeti dal passo 2
"""
# Step 1: Salva i pesi iniziali
initial_weights = save_initial_weights(model)
masks = {name: torch.ones_like(param)
for name, param in model.named_parameters()
if 'weight' in name}
cumulative_pruned = 0.0
results = []
for round_idx in range(n_rounds):
print(f"\n--- Round IMP {round_idx + 1}/{n_rounds} ---")
# Step 2: Addestra il modello (con le maschere correnti applicate)
train_fn(model, epochs=epochs_per_round, masks=masks)
# Step 3: Calcola nuove maschere di pruning
effective_prune = 1 - (1 - prune_per_round) ** (round_idx + 1)
new_masks = compute_pruning_masks(model, effective_prune)
# Step 4: Reinizializza con pesi iniziali e nuove maschere
model = apply_mask_and_reinit(model, initial_weights, new_masks)
masks = new_masks
# Valutazione
accuracy = eval_fn(model)
total_sparsity = sum(
(m == 0).float().mean().item()
for m in masks.values()
) / len(masks)
results.append({
'round': round_idx + 1,
'accuracy': accuracy,
'sparsity': total_sparsity
})
print(f"Accuratezza: {accuracy:.2%} | Sparsita: {total_sparsity:.1%}")
return model, results
# Risultati tipici IMP su ResNet-20 / CIFAR-10:
# Round 1 (20% pruned): 91.8% accuracy (baseline: 91.9%)
# Round 2 (36% pruned): 91.7% accuracy
# Round 3 (49% pruned): 91.5% accuracy
# Round 4 (59% pruned): 91.2% accuracy
# Round 5 (67% pruned): 90.8% accuracy <- "winning ticket"
# Round 8 (83% pruned): 89.1% accuracy <- accuratezza inizia a degradare
# Round 10 (89% pruned): 87.3% accuracy <- soglia tipica fine utilita
LTH in Pratica: Limitazioni
- Costo computazionale: IMP richiede molti cicli di train-prune-reinit, rendendolo costoso per modelli grandi. Per i LLM, si usano varianti più efficienti come GMP (Gradual Magnitude Pruning) che non richiedono reinizializzazione.
- Scalabilità: La LTH originale funziona su modelli piccoli. Per BERT e GPT, la reinizializzazione ai pesi iniziali non produce benefici chiari; si usa pruning + fine-tuning sui pesi correnti.
- Transfer learning: Ricerche del 2020 (Chen et al.) mostrano che i "winning tickets" di modelli pre-addestrati come BERT sono transferibili a task downstream, aprendo interessanti applicazioni.
Workflow di Pruning Iterativo con Retraining
Il workflow più efficace in produzione non e il one-shot pruning (rimuovi subito il 50% dei pesi) ma il pruning iterativo con retraining: pruna gradualmente, lasciando alla rete il tempo di "recuperare" ad ogni step. Questo produce modelli significativamente più accurati a parita di sparsita target.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.optim.lr_scheduler import CosineAnnealingLR
# ===================================================================
# WORKFLOW PRUNING ITERATIVO COMPLETO
# ===================================================================
def get_global_sparsity(model: nn.Module) -> float:
"""Calcola la sparsita globale del modello."""
total_params = 0
zero_params = 0
for name, param in model.named_parameters():
if 'weight' in name:
total_params += param.numel()
zero_params += (param == 0).sum().item()
return zero_params / total_params if total_params > 0 else 0.0
def iterative_pruning_with_finetuning(
model: nn.Module,
train_loader,
val_loader,
target_sparsity: float = 0.70,
n_pruning_steps: int = 7,
finetune_epochs_per_step: int = 3,
lr: float = 1e-4,
device: str = 'cuda'
):
"""
Pruning iterativo con fine-tuning post-pruning.
Strategia: aumenta la sparsita gradualmente usando una schedule
cubica (più aggressiva all'inizio, più conservativa alla fine).
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
history = []
# Schedule di sparsita cubica
sparsity_schedule = [
1 - (1 - target_sparsity * (step / n_pruning_steps) ** 3)
for step in range(1, n_pruning_steps + 1)
]
print(f"Schedule sparsita: {[f'{s:.1%}' for s in sparsity_schedule]}")
for step_idx, target_sparsity_step in enumerate(sparsity_schedule):
print(f"\n=== Step {step_idx + 1}/{n_pruning_steps} | Target sparsita: {target_sparsity_step:.1%} ===")
# Raccoglie tutti i parametri weight del modello
parameters_to_prune = [
(module, 'weight')
for name, module in model.named_modules()
if isinstance(module, (nn.Linear, nn.Conv2d))
]
# Pruning globale L1
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=target_sparsity_step
)
actual_sparsity = get_global_sparsity(model)
print(f"Sparsita effettiva: {actual_sparsity:.1%}")
# Fine-tuning post-pruning
scheduler = CosineAnnealingLR(optimizer, T_max=finetune_epochs_per_step)
for epoch in range(finetune_epochs_per_step):
model.train()
train_loss = 0.0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
out = model(batch_x)
loss = criterion(out, batch_y)
loss.backward()
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Valutazione
model.eval()
correct = total = 0
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
pred = model(batch_x).argmax(dim=1)
correct += (pred == batch_y).sum().item()
total += batch_y.size(0)
val_acc = correct / total
history.append({'step': step_idx+1, 'sparsity': actual_sparsity, 'val_acc': val_acc})
print(f"Val accuracy: {val_acc:.2%}")
# Consolida le maschere (rende il pruning permanente)
for module, param_name in parameters_to_prune:
try:
prune.remove(module, param_name)
except ValueError:
pass # Già rimosso
return model, history
Pruning + Quantizzazione: Massima Compressione
Pruning e quantizzazione sono tecniche complementari e si combinano efficacemente. Il pruning riduce il numero di parametri; la quantizzazione riduce la precisione di ciascun parametro rimanente. Applicati insieme, producono modelli estremamente compatti. Questa combinazione e nota come "sparse quantization" o "quantized sparse models".
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ===================================================================
# COMBINAZIONE PRUNING + QUANTIZZAZIONE
# ===================================================================
# --- Approccio 1: Pruning strutturato + Quantizzazione INT8 ---
# Pruna prima (rimuove strutture), poi quantizza il modello ridotto
def prune_and_quantize_pipeline(model_name: str, prune_ratio: float = 0.30):
"""
Pipeline: carica modello -> pruning strutturato -> quantizzazione INT8.
"""
# Step 1: Carica modello full precision
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
torch_dtype=torch.float32
)
print(f"Parametri originali: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# Step 2: Pruning L1 non strutturato globale
parameters_to_prune = [
(module, 'weight')
for name, module in model.named_modules()
if isinstance(module, nn.Linear) and 'classifier' not in name
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=prune_ratio
)
# Consolida maschere
for module, param_name in parameters_to_prune:
prune.remove(module, param_name)
# Conta parametri zero
zero_params = sum(
(param == 0).sum().item()
for name, param in model.named_parameters()
if 'weight' in name
)
total_params = sum(
param.numel()
for name, param in model.named_parameters()
if 'weight' in name
)
print(f"Sparsita dopo pruning: {zero_params/total_params:.1%}")
# Step 3: Quantizzazione dinamica INT8 del modello pruned
model_quantized = torch.quantization.quantize_dynamic(
model,
{nn.Linear}, # Quantizza solo layer Linear
dtype=torch.qint8
)
return model_quantized
# --- Approccio 2: QLoRA su modello pre-pruned ---
# Per LLM: usa modelli già pruned + quantizzazione NF4 per fine-tuning
config_nf4 = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
# Molti modelli su HuggingFace Hub sono già pruned E quantizzati:
# es. "microsoft/phi-2" (2.7B), "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Questi sono modelli "distilled + pruned" during pretraining.
# --- Benchmark memoria: pruning + quantizzazione ---
compression_results = [
{"metodo": "FP32 (baseline)", "sparsita": "0%", "precisione": "FP32", "size_mb": 440},
{"metodo": "Pruning 50%", "sparsita": "50%", "precisione": "FP32", "size_mb": 220},
{"metodo": "Quantizzazione INT8", "sparsita": "0%", "precisione": "INT8", "size_mb": 110},
{"metodo": "Pruning 50% + INT8", "sparsita": "50%", "precisione": "INT8", "size_mb": 55},
{"metodo": "Pruning 70% + INT4", "sparsita": "70%", "precisione": "INT4", "size_mb": 33},
]
print(f"\n{'Metodo':<28} {'Sparsita':>10} {'Precisione':>12} {'Dimensione':>12}")
print("-" * 65)
for r in compression_results:
print(f"{r['metodo']:<28} {r['sparsita']:>10} {r['precisione']:>12} {r['size_mb']:>10} MB")
# Output (modello BERT-base ~440MB in FP32):
# Metodo Sparsita Precisione Dimensione
# FP32 (baseline) 0% FP32 440 MB
# Pruning 50% 50% FP32 220 MB
# Quantizzazione INT8 0% INT8 110 MB
# Pruning 50% + INT8 50% INT8 55 MB
# Pruning 70% + INT4 70% INT4 33 MB
Benchmark: Accuratezza, Speedup e Memoria
I risultati del pruning variano significativamente in base al modello, al task e al metodo. La seguente tabella riporta benchmark indicativi per BERT-base e ResNet-50, basati su risultati della letteratura e sperimentazioni pratiche:
| Modello | Metodo | Sparsita | Accuratezza | Speedup | Memoria |
|---|---|---|---|---|---|
| BERT-base (MNLI) | Baseline FP16 | 0% | 84.6% | 1.0x | 440 MB |
| BERT-base (MNLI) | Magnitude unstr. | 50% | 84.1% | 1.0x* | 440 MB* |
| BERT-base (MNLI) | Movement pruning | 70% | 83.5% | 1.0x* | 440 MB* |
| BERT-base (MNLI) | Head pruning 30% | 30% heads | 84.0% | 1.3x | 310 MB |
| BERT-base (SQuAD) | Block pruning str. | 50% | F1 -1% | 2.4x | 220 MB |
| ResNet-50 (ImageNet) | L1 filter pruning | 40% | Top-1 -0.5% | 1.5x | -40% |
| ResNet-50 (ImageNet) | Pruning iterativo | 70% | Top-1 -1.2% | 2.1x | -65% |
* Pruning non strutturato: nessun speedup su hardware standard senza sparse ops dedicate.
Raccomandazioni per Tipo di Hardware Target
- GPU NVIDIA standard: Preferire il pruning strutturato (Torch-Pruning, head pruning). Il pruning non strutturato non porta benefici senza supporto sparse dedicato, a meno di usare il formato 2:4 sparsity di NVIDIA Ampere (50% sparsita in pattern specifici 2 non-zero ogni 4).
- CPU (deployment inference): Il pruning non strutturato ad alta sparsita (>80%) può portare speedup con librerie come Intel oneDNN o con conversione a formato CSR/CSC. Ma il pruning strutturato rimane più predicibile.
- Edge devices (Jetson, Raspberry Pi): Pruning strutturato + quantizzazione INT8 o GGUF. La riduzione del modello e critica: anche 2x meno parametri può fare la differenza tra eseguibile e non eseguibile.
- Mobile (ARM): Usare librerie come XNNPACK o CoreML con quantizzazione INT8 e pruning strutturato per accelerazione hardware reale.
Best Practices e Anti-Pattern
Best Practices per il Pruning
- Usa pruning iterativo, non one-shot: Pruna il 10-20% per step con retraining intermedio. Un'unica rimozione aggressiva del 70% degrada quasi sempre l'accuratezza in modo irreversibile.
- Applica il retraining dopo ogni step: Anche 1-3 epoche di fine-tuning dopo ogni round di pruning recuperano la maggior parte dell'accuratezza persa. Il learning rate deve essere basso (10-100x inferiore al training originale).
- Scegli il metodo in base all'hardware target: Pruning strutturato per speedup reali su hardware standard; non strutturato solo se hai accesso a hardware sparse-capable.
- Non prunare i layer critici: Il primo e ultimo layer di ogni rete (embedding, classificatore) sono i più sensibili. Escludi o riduci fortemente il pruning su questi layer.
- Monitora la distribuzione dei pesi durante il pruning: Se troppi pesi di uno stesso layer vengono pruned (>80%), il layer potrebbe collassare. Imposta un limite minimo per layer.
- Valuta su metriche del task, non solo loss: La loss di training può non catturare degradazioni su casi edge. Usa metriche specifiche del dominio (F1, BLEU, accuratezza su test set).
Anti-Pattern da Evitare
-
Non aspettarti speedup da pruning non strutturato su GPU standard:
L'API
torch.nn.utils.prunezerizza i pesi ma non li rimuove fisicamente. Il tempo di inferenza non diminuisce senza sparse ops dedicate. -
Non mischiare maschere e pesi senza consolidare: Prima di esportare o
distribuire il modello, chiama sempre
prune.remove(module, 'weight')per consolidare la maschera nel parametro. Altrimenti il modello ha overhead di memoria e dipendenze non portabili. - Non usare un dataset di validazione troppo piccolo: Il pruning aggressivo può causare overfitting sul validation set usato per monitorare l'accuratezza. Usa un held-out test set per la valutazione finale.
- Non ignorare i layer di normalizzazione: BatchNorm e LayerNorm mantengono statistiche legate alle dimensioni dei layer precedenti. Dopo pruning strutturato, le statistiche di normalizzazione devono essere ricalibrate (re-run sul dataset di calibrazione).
- Non applicare pruning su modelli non convergenti: Il pruning funziona meglio su modelli ben addestrati. Applicarlo su un modello che non ha ancora convergito produce risultati imprevedibili.
Pruning nel 2025-2026: Stato dell'Arte
Il campo del pruning si e evoluto significativamente con l'ascesa dei LLM. Le tendenze principali nel 2025-2026 includono:
- SparseGPT e Wanda: Metodi di pruning one-shot per LLM che non richiedono retraining. SparseGPT (Frantar & Alistarh, 2023) usa l'inversa approssimata della matrice Hessiana per aggiornare i pesi rimanenti, compensando l'errore del pruning. Wanda (Sun et al., 2023) usa product of weight magnitude and input activation norms come criterio.
- 2:4 Sparsity (NVIDIA): Pattern di sparsita strutturata supportato hardware su Ampere e Hopper GPUs: esattamente 2 valori non-zero ogni 4 elementi. Produce speedup di ~1.5-2x in operazioni sparse su A100/H100 con accuratezza quasi identica al modello denso.
- CORP (2025): Closed-Form One-shot Representation-Preserving Structured Pruning per Vision Transformers — scala da DeiT-Tiny a DeiT-Huge con speedup hardware reali e minima perdita di accuratezza.
- Pruning + Distillazione: Combinare pruning con knowledge distillation (articolo precedente di questa serie) produce i risultati migliori: il modello pruned viene addestrato con supervisione del modello teacher originale.
Conclusioni
Il pruning delle reti neurali e una delle tecniche di compressione più mature e versatili nel deep learning. La comprensione della distinzione tra pruning strutturato e non strutturato e fondamentale: il primo produce speedup reali su hardware standard, il secondo richiede supporto specifico per la sparsita ma offre maggiore flessibilità.
Il pruning iterativo con retraining rimane il gold standard per la qualità dei risultati. La Lottery Ticket Hypothesis offre una visione teorica fondamentale sul perchè il pruning funziona, pur avendo limitazioni pratiche per modelli molto grandi. Per i LLM moderni, metodi come SparseGPT e Wanda offrono alternative one-shot praticabili.
La combinazione pruning + quantizzazione e la strada maestra per la massima compressione: ridurre il numero di parametri e la loro precisione numerica in modo complementare permette di ottenere modelli con footprint 10-15x inferiore al punto di partenza, mantenendo accuratezza accettabile per la maggior parte degli use case produttivi.
Prossimi Passi
- Articolo successivo: Ollama: Eseguire LLM Locali su Laptop e Raspberry
- Articolo precedente: Distillazione Modelli: Knowledge Transfer
- Correlato: Quantizzazione Modelli: GPTQ, AWQ, INT8
- Correlato: Fine-tuning con LoRA e QLoRA
- Serie MLOps: Model Serving e Deployment







