Prořezávání neuronových sítí: Snížení složitosti modelu
Model ResNet-50 má více než 25 milionů parametrů. GPT-3 má 175 miliard. Přesto výzkum systematika ukazuje, že většina těchto parametrů je redundantních: trénované neuronové sítě mohou ztratit více než 90 % své hmotnosti bez výrazné degradace přesnosti. The prořezávání — technika systematického odstraňování nadbytečných parametrů — a jeden z nejúčinnějších nástrojů pro snížení výpočetní složitosti modelů hlubokého učení.
Na rozdíl od kvantizace, která snižuje numerickou přesnost parametrů, prořezávání li úplně odstranit. Výsledkem může být menší, rychlejší a levnější model provoz je nákladný – zvláště při použití strukturované prořezávání, to odstraňuje celé neurony, filtry nebo hlavy pozornosti a vytváří skutečné zrychlení na hardwaru bez požádat o podporu pro spasita.
V této příručce prozkoumáme prořezávání do hloubky: z teorie Hypotéza loterijního lístku až po praktické implementace s PyTorchem, od prořezávání podle velikosti až po prořezávání pohybu o Transformátory, až po iterativní pracovní postupy a kombinace s kvantizací.
Co se naučíte
- Rozdíl mezi strukturovaným a nestrukturovaným prořezáváním a kdy je použít
- Velikost prořezávání: nejjednodušší a nejúčinnější metoda
- Pohybové prořezávání pro moderní transformátory a LLM
- Hypotéza loterijního lístku: teorie, která vysvětluje, proč prořezávání funguje
- PyTorch prořezávací API s kompletními příklady
- Iterativní pracovní postup prořezávání s rekvalifikací
- Torch-Pruning pro pokročilé strukturované prořezávání
- Kombinace prořezávání + kvantizace pro maximální kompresi
- Skutečná měřítka přesnosti, paměti a rychlosti
- Osvědčené postupy a běžné anti-vzorce
Proč prořezávání? Problém redundance
Moderní neuronové sítě jsou notoricky přeparametrizované. Tato redundance a částečně záměrné: větší sítě se snadněji trénují a lépe zobecňují, ale v tuto chvíli nasazení s sebou přináší zbytečnou výpočetní váhu. Tři klíčová empirická pozorování motivovat prořezávání:
- Redundance hmotnosti: Studie degenerativního ořezávání to ukazují na trénovaných sítích rozložení hmotnosti je silně koncentrováno kolem nuly. Odstraňte menší závaží velikost má minimální vliv na předpovědi.
- Hypotéza loterijního lístku (Frankle & Carlin, 2019): Každá trénovaná neuronová síť obsahuje „vítěznou“ podsíť, která je-li znovu inicializována s původními hodnotami a trénována sám o sobě dosahuje výkonu srovnatelného s kompletní sítí.
- Nadměrná parametrizace jako nástroj: Další parametry jsou pro trénink (hladší krajina, únik z místních minim), ale pro usuzování nejsou nutné.
Dopad prořezávání: Skutečná data
Výzkum na ResNet a BERT ukazuje, že modely mohou minout 70-90 % parametrů se ztrátou přesnosti menší než 1-2 %. Prořezávání strukturovaného transformátoru na základně BERT s 50% řídkostí vede ke snížení FLOPů 2x a zrychlení vyvozování z 1,5x zachování přes 99 % původní přesnosti. V kontextu LLM, Techniky blokového prořezávání pro Transformers prokázaly zrychlení až 2,4x na SQuAD pouze s 1% ztrátou F1.
Strukturované vs. Nestrukturované prořezávání
Klíčový rozdíl v prořezávání je mezi přístupy strukturovaný e nestrukturované. Výběr závisí na cílovém hardwaru a cílech nasazení:
| čekám | Nestrukturované | Strukturovaný |
|---|---|---|
| Co odstraňuje | Individuální váhy (libovolné) | Neurony, filtry, kanály, hlavy pozornosti, vrstvy |
| Výsledné šíření | Nepravidelné (řídká matrice) | Normální (malá velikost) |
| Skutečné zrychlení na standardním CPU/GPU | Žádné (bez řídkých operací) | Ano, okamžitě s hustými ops |
| Zrychlení na řídkém hardwaru (řídký CPU, Cerebras) | Si | Si |
| Snížení paměti | Pouze s explicitním řídkým formátem | Vždy (malá velikost) |
| Přesnost se stejnou vzácností | zlepšit | Mírně nižší |
| Složitost implementace | Jednoduchý | Složitější (přepočet závislostí) |
Il nestrukturované prořezávání a flexibilnější: dokáže odstranit jakoukoli váhu bez ohledu na jeho umístění. Problém je, že výsledná matrice zůstává hustá v paměti (explicitní nuly) a modernímu hardwaru neprospívá nerovnoměrná řídkost bez specifické podpory (NVIDIA zavedla podporu pro 2:4 sparity s GPU Ampere, ale vyžaduje specifické vzory). The strukturované prořezávání, odstranění struktur kompletní, produkuje ověřitelně menší modely: Lineární vrstva s 512 neurony prořezaný na 256 se jednoduše stane lineárním (in, 256), provedeným se standardními hustými operacemi.
Prořezávání velikosti: základní metoda
Il velikost prořezávání a nejjednodušší a překvapivě účinný přístup: odstranit závaží s absolutní hodnotou menší než prahová hodnota. Intuitivní logika a ta váha malé přispívají málo k signálu přenášenému sítí. Navzdory své jednoduchosti, v kombinaci s iterativní rekvalifikací přináší výsledky, které jsou konkurenceschopné s mnohem lepšími metodami sofistikované.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np
# ===================================================================
# MAGNITUDE PRUNING CON PYTORCH NATIVE API
# ===================================================================
class ConvNet(nn.Module):
"""Modello CNN semplice per dimostrare il pruning."""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(4)
)
self.classifier = nn.Sequential(
nn.Linear(128 * 4 * 4, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
model = ConvNet()
# --- Pruning L1 non strutturato (magnitude-based) ---
# Rimuove il 30% dei pesi con valore assoluto minore
prune.l1_unstructured(
model.features[0], # Layer da pruning
name='weight', # Parametro da prunare
amount=0.30 # Percentuale da rimuovere (30%)
)
# --- Pruning Random (baseline di confronto) ---
prune.random_unstructured(
model.features[2],
name='weight',
amount=0.30
)
# --- Analisi sparsita risultante ---
def compute_sparsity(module):
"""Calcola la sparsita effettiva di un modulo."""
total = 0
zeros = 0
for param in module.parameters():
total += param.numel()
zeros += (param == 0).sum().item()
return zeros / total if total > 0 else 0.0
print("Sparsita Conv1:", f"{compute_sparsity(model.features[0]):.1%}")
print("Sparsita Conv2:", f"{compute_sparsity(model.features[2]):.1%}")
# --- Verifica la struttura interna del pruning ---
# PyTorch crea weight_orig (originale) + weight_mask (0/1)
print("\nParametri di model.features[0] dopo pruning:")
for name, param in model.features[0].named_parameters():
print(f" {name}: shape={param.shape}")
for name, buf in model.features[0].named_buffers():
print(f" buffer {name}: shape={buf.shape}")
# --- Rimozione della maschera (make permanent) ---
# Dopo retraining, si consolida: il modello torna a usare 'weight'
prune.remove(model.features[0], 'weight')
print("\nDopo prune.remove: parametri di model.features[0]:")
for name, _ in model.features[0].named_parameters():
print(f" {name}")
# --- Global Pruning: pruning globale su tutto il modello ---
# Più efficace del pruning per-layer: usa una soglia globale
parameters_to_prune = (
(model.features[0], 'weight'),
(model.features[2], 'weight'),
(model.classifier[0], 'weight'),
(model.classifier[2], 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.40, # Rimuove 40% globalmente (non per layer)
)
# Sparsita finale per layer
for module_name, module in model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
if hasattr(module, 'weight_mask'):
sparsity = (module.weight_mask == 0).float().mean().item()
print(f"{module_name}: sparsita {sparsity:.1%}")
Upozornění: PyTorch Native Pruning neurychluje vyvozování
API torch.nn.utils.prune aplikovat jeden maska binární na vahách, nulování
vybrané, ale zachování původní husté struktury. Výsledný model zabírá
stejná paměť a předání vpřed trvá stejnou dobu. Chcete-li dosáhnout skutečného zrychlení, potřebujete:
strukturované prořezávání (s fyzickým odstraněním struktur), nebo řídce specifické knihovny
operace. Nativní prořezávání PyTorch je skvělé pro experimentování a pro QAT (Quantization-Aware
trénink) s řídkým, ale ne pro přímé nasazení.
Strukturované prořezávání s pochodňovým prořezáváním
Knihovna Prořezávání pochodní (Fang et al., CVPR 2023) řeší problém skutečné strukturované ořezávání: Odstranění filtru z vrstvy Conv2D vyžaduje také aktualizaci další vrstva (která očekává N vstupních kanálů, nikoli N-k). Rukojeti na pochodeň automaticky tyto závislosti prostřednictvím grafu závislostí (DepGraph), podporující komplexní architektury včetně ViT, LLM, YOLO a modelů s přeskočením připojení.
# pip install torch-pruning
import torch
import torch.nn as nn
import torch_pruning as tp
# ===================================================================
# PRUNING STRUTTURATO CON TORCH-PRUNING
# ===================================================================
class ResidualBlock(nn.Module):
"""Blocco residuale: Torch-Pruning gestisce la skip connection automaticamente."""
def __init__(self, channels=64):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
return self.relu(out + residual) # Skip connection
class SimpleResNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.layer1 = ResidualBlock(64)
self.layer2 = ResidualBlock(64)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.pool(x).view(x.size(0), -1)
return self.fc(x)
model = SimpleResNet()
model.eval()
# Input di esempio per tracciare le dipendenze
example_input = torch.randn(1, 3, 32, 32)
# --- Costruzione del grafo delle dipendenze ---
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)
# --- Analisi del modello PRIMA del pruning ---
macs_before, params_before = tp.utils.count_ops_and_params(model, example_input)
print(f"Parametri PRIMA: {params_before / 1e6:.2f}M")
print(f"MACs PRIMA: {macs_before / 1e9:.3f}G")
# --- Definizione della strategia di pruning ---
# Pruning per magnitudine L1 dei filtri (L2 disponibile con tp.strategy.L2Strategy)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_input,
importance=tp.importance.MagnitudeImportance(p=1), # L1 norm
iterative_steps=5, # Pruning iterativo in 5 step
ch_sparsity=0.5, # Rimuove il 50% dei canali
ignored_layers=[model.fc], # Non pruning il classificatore finale
)
# --- Esecuzione del pruning (un singolo step) ---
pruner.step()
# --- Analisi del modello DOPO il pruning ---
macs_after, params_after = tp.utils.count_ops_and_params(model, example_input)
print(f"\nParametri DOPO: {params_after / 1e6:.2f}M")
print(f"MACs DOPO: {macs_after / 1e9:.3f}G")
print(f"Riduzione parametri: {(1 - params_after/params_before):.1%}")
print(f"Riduzione MACs: {(1 - macs_after/macs_before):.1%}")
# --- Verifica architettura post-pruning ---
print("\nArchitettura post-pruning:")
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
print(f" {name}: Conv2d({module.in_channels}, {module.out_channels}, ...)")
# Output tipico:
# Parametri PRIMA: 0.15M | MACs PRIMA: 0.009G
# Parametri DOPO: 0.04M | MACs DOPO: 0.003G
# Riduzione parametri: 75% | Riduzione MACs: 72%
# layer1.conv1: Conv2d(32, 32, ...) <- da 64 a 32 canali
Pohybové prořezávání pro transformátory
Prořezávání velikosti funguje dobře pro CNN, ale Transformers představují jinou výzvu: váhy pozornosti mohou mít nízkou velikost, ale být kritické pro chování model. The pohybové prořezávání (Sanh et al., 2020) řeší tento problém s radikálně odlišným přístupem: místo odstranění závaží malý, odstraní ti, kteří jsou blížící se nule během jemného ladění. Jinými slovy, kritérium a gradient hmotnosti vzhledem k cíli prořezávání, nikoli aktuální hodnota hmotnosti.
Pohybové prořezávání prokázalo významné výhody pro prořezávání modelů BERT: při vysoké řídkosti (80–97 %) překračuje prořezávání pohybu velikost prořezávání o 10–20 bodů procenta na benchmarky NLP, jako jsou MNLI a SQuAD.
# Movement Pruning per Transformer con Hugging Face + SparseML
# pip install transformers datasets sparseml
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
# ===================================================================
# MOVEMENT PRUNING MANUALE (concetto base)
# ===================================================================
class MovementPruningLinear(nn.Module):
"""
Layer Linear con movement pruning.
Mantiene uno score per ogni peso: lo score viene ottimizzato
durante il training. I pesi con score basso vengono pruned.
"""
def __init__(self, in_features, out_features, pruning_ratio=0.5):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.bias = nn.Parameter(torch.zeros(out_features))
# Score inizializzati a zero: durante il training salgono per i pesi importanti
self.scores = nn.Parameter(torch.zeros_like(self.weight))
self.pruning_ratio = pruning_ratio
self.mask = None
def update_mask(self):
"""Aggiorna la maschera basandosi sugli score correnti."""
k = int(self.scores.numel() * (1 - self.pruning_ratio))
# Top-k scores: mantieni i pesi con score più alto
threshold = torch.kthvalue(self.scores.flatten(), self.scores.numel() - k).values
self.mask = (self.scores >= threshold).float().detach()
def forward(self, x):
# Applica la maschera durante il forward pass
if self.mask is None:
self.update_mask()
masked_weight = self.weight * self.mask
return nn.functional.linear(x, masked_weight, self.bias)
# ===================================================================
# PRUNING PRATICO CON TRANSFORMERS + torch.nn.utils.prune
# ===================================================================
def prune_transformer_attention_heads(model, heads_to_prune):
"""
Pruna specifici attention heads da un modello BERT-like.
heads_to_prune: dict {layer_idx: [head_idx_1, head_idx_2, ...]}
"""
model.prune_heads(heads_to_prune)
return model
# Esempio: pruning degli attention heads meno importanti
# Identificazione heads da pruning (basata su Taylor importance)
def compute_head_importance(model, dataloader, device):
"""
Calcola l'importanza di ogni attention head usando Taylor expansion.
Un head e importante se rimuoverlo aumenta molto la loss.
"""
model.eval()
head_importance = torch.zeros(
model.config.num_hidden_layers,
model.config.num_attention_heads
).to(device)
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch, output_attentions=True)
loss = outputs.loss
loss.backward()
# Accumula gradienti per stimare l'importanza
for layer_idx, layer in enumerate(model.bert.encoder.layer):
attn = layer.attention.self
# Importanza approssimata: |grad * weight| sommato per head
grad = attn.value.weight.grad
weight = attn.value.weight
if grad is not None:
importance = (grad * weight).abs().view(
model.config.num_attention_heads, -1
).sum(dim=-1)
head_importance[layer_idx] += importance
return head_importance
# ===================================================================
# STRUCTURED PRUNING DI ATTENTION HEADS CON BERT
# ===================================================================
model_name = "bert-base-uncased"
# model = AutoModelForSequenceClassification.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# Strategia di pruning: rimuovi il 30% degli heads meno importanti
# Assumendo head_importance calcolata come sopra:
# heads_to_prune = {}
# n_heads_to_prune = int(0.3 * 12 * 12) # 30% di 144 heads totali (12 layers x 12 heads)
# flat_importance = head_importance.flatten()
# _, indices = flat_importance.sort()
# for idx in indices[:n_heads_to_prune]:
# layer_idx = idx.item() // 12
# head_idx = idx.item() % 12
# if layer_idx not in heads_to_prune:
# heads_to_prune[layer_idx] = []
# heads_to_prune[layer_idx].append(head_idx)
# pruned_model = prune_transformer_attention_heads(model, heads_to_prune)
print("Movement pruning e head importance pruning: schema implementato.")
print("Risultati tipici su BERT-base con 40% pruning attention:")
print(" - Speedup inferenza: 1.3-1.5x")
print(" - Dimensione modello: -35%")
print(" - Accuratezza GLUE: -0.5 a -1.5 punti")
Hypotéza loterijního lístku: Teorie vítězného podmodelu
La Hypotéza loterijního lístku (LTH, Frankle & Carlin, NeurIPS 2019) a jeden z Nejvlivnější teoretické poznatky při prořezávání: Každá hustá neuronová síť obsahuje jednu nebo více podsítí řídké („výherní tikety“), které, pokud jsou extrahovány a znovu inicializovány s původními počátečními hodnotami, mohou být trénováni samostatně a dosahovat přesnosti srovnatelné nebo lepší než kompletní síť, za kratší nebo stejný tréninkový čas.
LTH má důležité praktické důsledky: naznačuje, že velký model je primárně užitečný pro nalézt správná struktura, nikoli pro vnitřní schopnosti jejích parametrů. Standardní postup pro nalezení výherního tiketu jeIterativní prořezávání velikosti (IMP).
import torch
import torch.nn as nn
import copy
from typing import Dict, List
# ===================================================================
# ITERATIVE MAGNITUDE PRUNING (Lottery Ticket Hypothesis)
# ===================================================================
def save_initial_weights(model: nn.Module) -> Dict[str, torch.Tensor]:
"""Salva i pesi iniziali del modello (prima del training)."""
return {
name: param.data.clone()
for name, param in model.named_parameters()
if 'weight' in name
}
def apply_mask_and_reinit(
model: nn.Module,
initial_weights: Dict[str, torch.Tensor],
masks: Dict[str, torch.Tensor]
) -> nn.Module:
"""
Reimposta i pesi ai valori iniziali con le maschere di pruning applicate.
Questo e il passo critico della LTH: reinizializzare (non random, ma ai valori originali).
"""
with torch.no_grad():
for name, param in model.named_parameters():
if name in initial_weights and name in masks:
param.data = initial_weights[name] * masks[name]
return model
def compute_pruning_masks(
model: nn.Module,
pruning_ratio: float
) -> Dict[str, torch.Tensor]:
"""Calcola le maschere di pruning per magnitude (L1)."""
masks = {}
for name, param in model.named_parameters():
if 'weight' in name and param.dim() > 1:
# Soglia globale per layer
threshold = torch.quantile(param.abs(), pruning_ratio)
masks[name] = (param.abs() >= threshold).float()
return masks
def iterative_magnitude_pruning(
model: nn.Module,
train_fn,
eval_fn,
n_rounds: int = 5,
prune_per_round: float = 0.20,
epochs_per_round: int = 10
):
"""
Implementazione dell'Iterative Magnitude Pruning (LTH).
Algoritmo:
1. Salva i pesi iniziali (w0)
2. Addestra per N epoche
3. Pruna il P% dei pesi con magnitudine minore
4. Reinizializza i pesi sopravvissuti a w0
5. Ripeti dal passo 2
"""
# Step 1: Salva i pesi iniziali
initial_weights = save_initial_weights(model)
masks = {name: torch.ones_like(param)
for name, param in model.named_parameters()
if 'weight' in name}
cumulative_pruned = 0.0
results = []
for round_idx in range(n_rounds):
print(f"\n--- Round IMP {round_idx + 1}/{n_rounds} ---")
# Step 2: Addestra il modello (con le maschere correnti applicate)
train_fn(model, epochs=epochs_per_round, masks=masks)
# Step 3: Calcola nuove maschere di pruning
effective_prune = 1 - (1 - prune_per_round) ** (round_idx + 1)
new_masks = compute_pruning_masks(model, effective_prune)
# Step 4: Reinizializza con pesi iniziali e nuove maschere
model = apply_mask_and_reinit(model, initial_weights, new_masks)
masks = new_masks
# Valutazione
accuracy = eval_fn(model)
total_sparsity = sum(
(m == 0).float().mean().item()
for m in masks.values()
) / len(masks)
results.append({
'round': round_idx + 1,
'accuracy': accuracy,
'sparsity': total_sparsity
})
print(f"Accuratezza: {accuracy:.2%} | Sparsita: {total_sparsity:.1%}")
return model, results
# Risultati tipici IMP su ResNet-20 / CIFAR-10:
# Round 1 (20% pruned): 91.8% accuracy (baseline: 91.9%)
# Round 2 (36% pruned): 91.7% accuracy
# Round 3 (49% pruned): 91.5% accuracy
# Round 4 (59% pruned): 91.2% accuracy
# Round 5 (67% pruned): 90.8% accuracy <- "winning ticket"
# Round 8 (83% pruned): 89.1% accuracy <- accuratezza inizia a degradare
# Round 10 (89% pruned): 87.3% accuracy <- soglia tipica fine utilita
LTH v praxi: Omezení
- Výpočetní náklady: IMP vyžaduje mnoho cyklů trénovat-prořezávat-reinit drahé pro velké modely. Pro LLM jsou efektivnější varianty, jako je GMP (Gradual Magnitude Pruning), které nevyžadují reinicializaci.
- Škálovatelnost: Původní LTH funguje na malých modelech. Pro BERT a GPT, reinicializace na počáteční váhy nepřináší žádné jasné výhody; používá se prořezávání + jemné doladění na aktuálních vahách.
- Přenést učení: Výzkum z roku 2020 (Chen et al.) ukazuje, že „vítězství vstupenky“ předem vyškolených modelů, jako je BERT, jsou přenosné na následné úkoly, otevření zajímavé aplikace.
Iterativní pracovní postup prořezávání s přeškolením
Nejúčinnějším pracovním postupem ve výrobě není jednorázové prořezávání (okamžitě odstraňte 50 % závaží) ale iterativní prořezávání s rekvalifikací: prořezávejte postupně, ponechejte síť čas se „vzpamatovat“ v každém kroku. Vznikají tak výrazně přesnější modely vzhledem ke stejné cílové řídkosti.
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from torch.optim.lr_scheduler import CosineAnnealingLR
# ===================================================================
# WORKFLOW PRUNING ITERATIVO COMPLETO
# ===================================================================
def get_global_sparsity(model: nn.Module) -> float:
"""Calcola la sparsita globale del modello."""
total_params = 0
zero_params = 0
for name, param in model.named_parameters():
if 'weight' in name:
total_params += param.numel()
zero_params += (param == 0).sum().item()
return zero_params / total_params if total_params > 0 else 0.0
def iterative_pruning_with_finetuning(
model: nn.Module,
train_loader,
val_loader,
target_sparsity: float = 0.70,
n_pruning_steps: int = 7,
finetune_epochs_per_step: int = 3,
lr: float = 1e-4,
device: str = 'cuda'
):
"""
Pruning iterativo con fine-tuning post-pruning.
Strategia: aumenta la sparsita gradualmente usando una schedule
cubica (più aggressiva all'inizio, più conservativa alla fine).
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
history = []
# Schedule di sparsita cubica
sparsity_schedule = [
1 - (1 - target_sparsity * (step / n_pruning_steps) ** 3)
for step in range(1, n_pruning_steps + 1)
]
print(f"Schedule sparsita: {[f'{s:.1%}' for s in sparsity_schedule]}")
for step_idx, target_sparsity_step in enumerate(sparsity_schedule):
print(f"\n=== Step {step_idx + 1}/{n_pruning_steps} | Target sparsita: {target_sparsity_step:.1%} ===")
# Raccoglie tutti i parametri weight del modello
parameters_to_prune = [
(module, 'weight')
for name, module in model.named_modules()
if isinstance(module, (nn.Linear, nn.Conv2d))
]
# Pruning globale L1
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=target_sparsity_step
)
actual_sparsity = get_global_sparsity(model)
print(f"Sparsita effettiva: {actual_sparsity:.1%}")
# Fine-tuning post-pruning
scheduler = CosineAnnealingLR(optimizer, T_max=finetune_epochs_per_step)
for epoch in range(finetune_epochs_per_step):
model.train()
train_loss = 0.0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
out = model(batch_x)
loss = criterion(out, batch_y)
loss.backward()
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Valutazione
model.eval()
correct = total = 0
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
pred = model(batch_x).argmax(dim=1)
correct += (pred == batch_y).sum().item()
total += batch_y.size(0)
val_acc = correct / total
history.append({'step': step_idx+1, 'sparsity': actual_sparsity, 'val_acc': val_acc})
print(f"Val accuracy: {val_acc:.2%}")
# Consolida le maschere (rende il pruning permanente)
for module, param_name in parameters_to_prune:
try:
prune.remove(module, param_name)
except ValueError:
pass # Già rimosso
return model, history
Prořezávání + kvantizace: Maximální komprese
Prořezávání a kvantování jsou doplňkové techniky a efektivně se kombinují. Prořezávání snižuje počet parametrů; kvantizace snižuje přesnost každého z nich zbývající parametr. Při společné aplikaci vytvářejí extrémně kompaktní modely. Tato kombinace je známá jako "řídké kvantování" o "kvantované řídké modely".
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ===================================================================
# COMBINAZIONE PRUNING + QUANTIZZAZIONE
# ===================================================================
# --- Approccio 1: Pruning strutturato + Quantizzazione INT8 ---
# Pruna prima (rimuove strutture), poi quantizza il modello ridotto
def prune_and_quantize_pipeline(model_name: str, prune_ratio: float = 0.30):
"""
Pipeline: carica modello -> pruning strutturato -> quantizzazione INT8.
"""
# Step 1: Carica modello full precision
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
torch_dtype=torch.float32
)
print(f"Parametri originali: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# Step 2: Pruning L1 non strutturato globale
parameters_to_prune = [
(module, 'weight')
for name, module in model.named_modules()
if isinstance(module, nn.Linear) and 'classifier' not in name
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=prune_ratio
)
# Consolida maschere
for module, param_name in parameters_to_prune:
prune.remove(module, param_name)
# Conta parametri zero
zero_params = sum(
(param == 0).sum().item()
for name, param in model.named_parameters()
if 'weight' in name
)
total_params = sum(
param.numel()
for name, param in model.named_parameters()
if 'weight' in name
)
print(f"Sparsita dopo pruning: {zero_params/total_params:.1%}")
# Step 3: Quantizzazione dinamica INT8 del modello pruned
model_quantized = torch.quantization.quantize_dynamic(
model,
{nn.Linear}, # Quantizza solo layer Linear
dtype=torch.qint8
)
return model_quantized
# --- Approccio 2: QLoRA su modello pre-pruned ---
# Per LLM: usa modelli già pruned + quantizzazione NF4 per fine-tuning
config_nf4 = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
# Molti modelli su HuggingFace Hub sono già pruned E quantizzati:
# es. "microsoft/phi-2" (2.7B), "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Questi sono modelli "distilled + pruned" during pretraining.
# --- Benchmark memoria: pruning + quantizzazione ---
compression_results = [
{"metodo": "FP32 (baseline)", "sparsita": "0%", "precisione": "FP32", "size_mb": 440},
{"metodo": "Pruning 50%", "sparsita": "50%", "precisione": "FP32", "size_mb": 220},
{"metodo": "Quantizzazione INT8", "sparsita": "0%", "precisione": "INT8", "size_mb": 110},
{"metodo": "Pruning 50% + INT8", "sparsita": "50%", "precisione": "INT8", "size_mb": 55},
{"metodo": "Pruning 70% + INT4", "sparsita": "70%", "precisione": "INT4", "size_mb": 33},
]
print(f"\n{'Metodo':<28} {'Sparsita':>10} {'Precisione':>12} {'Dimensione':>12}")
print("-" * 65)
for r in compression_results:
print(f"{r['metodo']:<28} {r['sparsita']:>10} {r['precisione']:>12} {r['size_mb']:>10} MB")
# Output (modello BERT-base ~440MB in FP32):
# Metodo Sparsita Precisione Dimensione
# FP32 (baseline) 0% FP32 440 MB
# Pruning 50% 50% FP32 220 MB
# Quantizzazione INT8 0% INT8 110 MB
# Pruning 50% + INT8 50% INT8 55 MB
# Pruning 70% + INT4 70% INT4 33 MB
Benchmarky: Přesnost, zrychlení a paměť
Výsledky prořezávání se výrazně liší podle modelu, úkolu a metody. Následující tabulka uvádí orientační benchmarky pro BERT-base a ResNet-50 na základě literární výsledky a praktické experimenty:
| Model | Metoda | Roztroušeně | Přesnost | Zrychlení | Paměť |
|---|---|---|---|---|---|
| Na BERT (MNLI) | Základní FP16 | 0% | 84,6 % | 1,0x | 440 MB |
| Na BERT (MNLI) | Velikost unstr. | 50 % | 84,1 % | 1,0x* | 440 MB* |
| Na BERT (MNLI) | Pohybové prořezávání | 70 % | 83,5 % | 1,0x* | 440 MB* |
| Na BERT (MNLI) | Prořezávání hlavy 30 % | 30 % hlav | 84,0 % | 1,3x | 310 MB |
| Na BERT (SQuAD) | Blokové prořezávání str. | 50 % | F1 -1 % | 2,4x | 220 MB |
| ResNet-50 (ImageNet) | Prořezávání filtru L1 | 40 % | První 1 – 0,5 % | 1,5x | -40 % |
| ResNet-50 (ImageNet) | Iterativní prořezávání | 70 % | První 1 – 1,2 % | 2,1x | -65 % |
* Nestrukturované ořezávání: žádné zrychlení na standardním hardwaru bez vyhrazených řídkých operací.
Doporučení podle cílového typu hardwaru
- Standardní GPU NVIDIA: Preferujte strukturované prořezávání (Torch-Pruning, prořezávání hlavy). Nestrukturované prořezávání nemá žádný přínos bez vyhrazené řídké podpory, pokud se nepoužívá 2:4 řídký formát NVIDIA Ampere (50% řídkost ve specifických vzorcích 2 nenulové každé 4).
- CPU (odvození z nasazení): Nestrukturované prořezávání s vysokou řídkostí (>80 %) může přinést zrychlení s knihovnami jako Intel oneDNN nebo s převodem do formátu CSR/CSC. Ale strukturované prořezávání zůstává předvídatelnější.
- Zařízení Edge (Jetson, Raspberry Pi): Strukturované prořezávání + kvantizace INT8 nebo GGUF. Redukce modelu je kritická: i 2x méně parametrů může znamenat rozdíl mezi spustitelným a nespustitelným.
- Mobil (ARM): Používejte knihovny jako XNNPACK nebo CoreML s kvantizací INT8 a strukturované prořezávání pro skutečnou hardwarovou akceleraci.
Osvědčené postupy a anti-vzorce
Nejlepší postupy pro prořezávání
- Použijte iterativní prořezávání, nikoli jednorázové: Prořezávejte 10-20 % na krok s rekvalifikací střední. Jediné agresivní 70% odstranění téměř vždy výrazně sníží přesnost nevratné.
- Po každém kroku aplikujte rekvalifikaci: Dokonce i 1-3 epochy jemného doladění později každé kolo prořezávání obnoví většinu ztracené přesnosti. Míra učení musí být nízká (10-100x nižší než původní trénink).
- Vyberte metodu na základě vašeho cílového hardwaru: Strukturované prořezávání pro urychlení skutečné na standardním hardwaru; nestrukturované pouze v případě, že máte přístup k hardwaru s omezenou schopností.
- Neprořezávejte kritické vrstvy: První a poslední vrstva každé sítě (vkládání, klasifikátor) jsou nejcitlivější. Vylučte nebo výrazně omezte prořezávání na těchto vrstvách.
- Sledujte rozložení hmotnosti během prořezávání: Pokud je příliš mnoho závaží jedné stejná vrstva se ořízne (>80 %), vrstva se může zhroutit. Nastavte minimální limit na vrstvu.
- Vyhodnoťte metriky úkolů, nejen ztráty: Tréninkovou ztrátu nemusí zachytit degradace na okrajových případech. Použijte metriky specifické pro doménu (F1, BLEU, přesnost na testovací sadě).
Anti-vzory, kterým je třeba se vyhnout
-
Neočekávejte zrychlení od nestrukturovaného prořezávání na standardních GPU:
API
torch.nn.utils.prunevynuluje váhy, ale fyzicky je neodstraní. Doba inference se nezkracuje bez vyhrazených řídkých operací. -
Nemíchejte masky a závaží bez konsolidace: Před exportem popř
distribuovat model, vždy volat
prune.remove(module, 'weight')pro konsolidujte masku do parametru. Jinak má model také režii paměti nepřenosné závislosti. - Nepoužívejte příliš malou ověřovací datovou sadu: Agresivní prořezávání může způsobit nadměrné vybavení ověřovací sady používané ke sledování přesnosti. Použijte a vyložený testovací soubor pro závěrečné vyhodnocení.
- Neignorujte normalizační vrstvy: BatchNorm a LayerNorm se udržují statistiky týkající se rozměrů předchozích vrstev. Po strukturovaném prořezávání se Normalizační statistiky je třeba překalibrovat (znovu spustit na sadě kalibračních dat).
- Neaplikujte prořezávání na nekonvergované modely: Nejlépe funguje prořezávání na dobře trénovaných modelech. Aplikuje se na model, který dosud nekonvergoval výnosy nepředvídatelné výsledky.
Prořezávání v letech 2025-2026: Stav techniky
Oblast prořezávání se výrazně vyvíjela se vzestupem LLM. Hlavní trendy v letech 2025-2026 zahrnují:
- SparseGPT a Wanda: Jednorázové metody prořezávání pro LLM, které nevyžadují rekvalifikace. SparseGPT (Frantar & Alistarh, 2023) používá přibližnou inverzní hodnotu matice Hessiana, aby aktualizoval zbývající váhy a kompenzoval chybu prořezávání. Wanda (Sun a kol., 2023) používá jako kritéria součin velikosti hmotnosti a vstupních aktivačních norem.
- 2:4 Sparsity (NVIDIA): Hardwarově podporovaný strukturovaný vzor řídkosti na GPU Ampere a Hopper: přesně 2 nenulové hodnoty každé 4 prvky. Produkuje zrychlení ~1,5-2x v řídkých operacích na A100/H100 s téměř identickou přesností jako hustý model.
- CORP (2025): Uzavřená forma Jednorázová reprezentace - Zachování strukturovaného prořezávání pro Vision Transformers – měřítka od DeiT-Tiny po DeiT-Huge se skutečným a minimálním zrychlením hardwaru ztráta přesnosti.
- Prořezávání + destilace: Kombinace prořezávání se znalostní destilací (předchozí článek v této sérii) přináší nejlepší výsledky: přichází ořezaný model vyškoleni pod dohledem původního učitelského modelu.
Závěry
Prořezávání neuronové sítě je jednou z nejvyspělejších a nejuniverzálnějších kompresních technik na světě hluboké učení. Pochopení rozdílu mezi prořezáváním strukturovaný e nestrukturované a základní: první přináší skutečné zrychlení hardwaru standard, druhý vyžaduje specifickou podporu řídkosti, ale nabízí větší flexibilitu.
Il iterativní prořezávání s rekvalifikací zůstává zlatým standardem kvality výsledky. Tam Hypotéza loterijního lístku nabízí základní teoretický pohled o tom, proč prořezávání funguje, přestože má u velmi velkých modelů praktická omezení. Pro moderní LLM nabízejí metody jako SparseGPT a Wanda životaschopné jednorázové alternativy.
Kombinace prořezávání + kvantování a hlavní silnice na maximum komprese: komplementárním způsobem snížit počet parametrů a jejich numerickou přesnost umožňuje získat modely s půdorysem 10-15x menším, než je výchozí bod, při zachování Přesnost přijatelná pro většinu případů použití ve výrobě.
Další kroky
- Další článek: Ollama: Spusťte místní LLM na notebooku a Raspberry
- Předchozí článek: Destilační modely: Přenos znalostí
- Související: Kvantizační modely: GPTQ, AWQ, INT8
- Související: Jemné doladění pomocí LoRA a QLoRA
- Řada MLOps: Obsluha a nasazení modelu







