01 - Attention Mechanism nei Transformer: Guida Completa
Nel 2017, un paper di Google Brain intitolato "Attention Is All You Need" ha cambiato per sempre il campo del deep learning. Gli autori, Vaswani e colleghi, hanno proposto un'architettura interamente basata su un meccanismo chiamato attention, eliminando le reti ricorrenti (RNN) e convoluzionali che dominavano fino a quel momento. Il risultato e stata l'architettura Transformer, oggi alla base di GPT-4, Claude, Llama 3, BERT, T5, Vision Transformers e praticamente ogni modello di frontiera.
Comprendere l'attention mechanism non e un esercizio accademico: e il fondamento su cui si costruiscono tecniche come LoRA fine-tuning, quantizzazione, pruning e deployment su dispositivi edge, tutti temi che affronteremo in questa serie. Senza una comprensione solida di come funziona l'attention, ogni ottimizzazione successiva resta una scatola nera.
In questo primo articolo della serie Deep Learning Avanzato e Edge Deployment, esploreremo l'attention in profondità: dall'intuizione iniziale alla formula matematica, dall'implementazione in PyTorch alle varianti moderne come Flash Attention 3 e Grouped-Query Attention.
Panoramica della Serie
| # | Articolo | Focus |
|---|---|---|
| 1 | Sei qui - Attention Mechanism nei Transformer | Self-attention, multi-head, architettura completa |
| 2 | Fine-tuning con LoRA, QLoRA e Adapters | Parameter-efficient fine-tuning |
| 3 | Quantizzazione dei Modelli | INT8, INT4, GPTQ, AWQ |
| 4 | Pruning e Compressione | Riduzione parametri, distillazione |
| 5 | Distillazione della Conoscenza | Teacher-student, knowledge transfer |
| 6 | Ollama e LLM Locali | Inference locale, ottimizzazione |
| 7 | Vision Transformer | ViT, DINO, image classification |
| 8 | Edge Deployment | ONNX, TensorRT, dispositivi mobili |
| 9 | NAS e AutoML | Neural Architecture Search |
| 10 | Benchmark e Ottimizzazione | Profiling, metriche, tuning |
Cosa Imparerai
- perchè le RNN e le LSTM non erano sufficienti per sequenze lunghe
- L'intuizione dietro il meccanismo di attention: Query, Key e Value
- La formula completa della Scaled Dot-Product Attention
- Come funziona Multi-Head Attention e perchè servono più teste
- La differenza tra Self-Attention e Cross-Attention
- Come il Positional Encoding risolve il problema dell'ordine
- L'architettura Transformer completa: encoder e decoder
- Implementazione pratica in PyTorch, riga per riga
- Le varianti moderne: Flash Attention 3, GQA, Sliding Window Attention
- Le architetture reali: GPT (decoder-only), BERT (encoder-only), T5 (encoder-decoder)
1. Il Problema delle Sequenze: Prima dell'Attention
Per comprendere perchè l'attention e stata una rivoluzione, dobbiamo partire dai modelli che l'hanno preceduta. Il deep learning per le sequenze (testo, audio, serie temporali) era dominato da due architetture: le RNN (Recurrent Neural Networks) e le LSTM (Long Short-Term Memory).
1.1 Le RNN e il Collo di Bottiglia Sequenziale
Le RNN elaborano le sequenze un token alla volta, passando uno stato nascosto (hidden state) da un passo temporale al successivo. Ogni token aggiorna lo stato nascosto, che funge da "memoria" della sequenza vista finora.
Input: x1 -----> x2 -----> x3 -----> x4 -----> x5
| | | | |
v v v v v
Hidden: h1 -----> h2 -----> h3 -----> h4 -----> h5
| |
v v
Output: y1 y5
Problema: h5 deve "ricordare" x1 attraverso 4 passaggi.
Con sequenze di 1000+ token, l'informazione di x1 svanisce.
Questo e il problema delle dipendenze a lungo termine (long-range dependencies). In una frase come "Il gatto, che era stato adottato dal rifugio tre anni fa e che viveva felicemente con la famiglia, dormiva sul divano", la RNN deve collegare "gatto" a "dormiva" attraverso decine di token intermedi. Lo stato nascosto, compresso in un vettore di dimensione fissa, perde inevitabilmente le informazioni più vecchie.
1.2 LSTM: Un Miglioramento, Non una Soluzione
Le LSTM hanno introdotto un meccanismo di gate (input gate, forget gate, output gate) per controllare quali informazioni mantenere e quali scartare. Questo ha migliorato la situazione, ma non l'ha risolta. Le LSTM soffrono ancora di due problemi fondamentali:
Limiti delle RNN/LSTM
| Problema | Descrizione | Impatto |
|---|---|---|
| Sequenzialita | Ogni token dipende dal precedente: non si può parallelizzare | Training lentissimo su sequenze lunghe |
| Collo di bottiglia | Tutta l'informazione passa attraverso un singolo vettore | Perdita di informazione con sequenze > 100-200 token |
| Gradient vanishing | I gradienti si riducono esponenzialmente durante la backpropagation | Il modello non riesce ad apprendere relazioni distanti |
Serviva un meccanismo che permettesse a ogni token di accedere direttamente a qualsiasi altro token della sequenza, senza dover passare attraverso stati intermedi. Questo meccanismo e l'attention.
2. Cos'è l'Attention: L'Intuizione
L'attention e un meccanismo che permette a un modello di concentrare la propria attenzione sulle parti più rilevanti dell'input quando genera l'output. Invece di comprimere tutta la sequenza in un singolo vettore, l'attention crea una connessione diretta tra ogni posizione di output e tutte le posizioni di input.
Analogia: La Ricerca in una Libreria
Immagina di essere in una libreria e di cercare informazioni sulla "storia dei Transformer". Hai in mente una domanda (Query). Ogni libro ha un titolo (Key) che descrive il suo contenuto. Quando il titolo corrisponde alla tua domanda, estrai il contenuto (Value) di quel libro. L'attention funziona esattamente cosi:
- Query (Q): "Cosa sto cercando?" - la domanda che il token corrente pone
- Key (K): "Cosa contiene questo elemento?" - l'etichetta di ogni token nella sequenza
- Value (V): "Ecco l'informazione" - il contenuto effettivo di ogni token
Il meccanismo calcola un punteggio di compatibilità tra la Query e ogni Key. Questo punteggio determina quanta attenzione dare al Value corrispondente. I punteggi vengono normalizzati tramite softmax per ottenere pesi che sommano a 1, e il risultato finale e una media pesata dei Value.
Token corrente: "dormiva"
Query di "dormiva": "Chi sta compiendo questa azione?"
Key Score Peso (softmax)
"Il" -----> 0.1 0.02
"gatto" -----> 4.8 0.65 <-- Alta attenzione!
"che" -----> 0.3 0.03
"era" -----> 0.2 0.02
"stato" -----> 0.1 0.02
"adottato" -----> 1.2 0.08
"..." -----> ... ...
"sul" -----> 2.1 0.12
"divano" -----> 0.8 0.06
Output = 0.02 * V("Il") + 0.65 * V("gatto") + 0.03 * V("che") + ...
Il modello ha imparato che "gatto" e il soggetto di "dormiva",
anche se sono separati da molti token.
3. Scaled Dot-Product Attention: La Formula
La formulazione matematica dell'attention usata nei Transformer e la Scaled Dot-Product Attention. E elegante nella sua semplicità e computazionalmente efficiente grazie all'uso di operazioni matriciali.
La Formula dell'Attention
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Dove:
- Q (Query): matrice di dimensione (n x d_k), dove n e il numero di token e d_k e la dimensione delle query/key
- K (Key): matrice di dimensione (n x d_k)
- V (Value): matrice di dimensione (n x d_v), dove d_v e la dimensione dei value
- d_k: dimensione delle key, usata come fattore di scaling
- Q * K^T: prodotto scalare tra query e key (matrice n x n di score)
- / sqrt(d_k): fattore di scala per stabilizzare i gradienti
- softmax: normalizza gli score in pesi che sommano a 1
3.1 perchè lo Scaling e Necessario
Senza il fattore sqrt(d_k), il prodotto scalare tra Q e K produce valori che crescono
proporzionalmente alla dimensione d_k. Con d_k = 512, i prodotti scalari possono raggiungere
valori molto grandi. Quando questi valori finiscono nella softmax, producono distribuzioni
quasi one-hot (un peso vicino a 1, tutti gli altri vicini a 0), con gradienti estremamente piccoli.
Lo scaling previene questo problema.
Senza scaling (d_k = 512):
Score raw: [120.3, 115.8, 2.1, -5.4]
Softmax: [0.989, 0.011, 0.000, 0.000] <-- Quasi one-hot, gradienti ~0
Con scaling (/ sqrt(512) = / 22.6):
Score scaled: [5.32, 5.12, 0.09, -0.24]
Softmax: [0.44, 0.36, 0.10, 0.10] <-- Distribuzione morbida, gradienti sani
3.2 Step by Step: Calcolo dell'Attention
Vediamo un esempio numerico concreto con una sequenza di 3 token e d_k = 4:
Sequenza: ["The", "cat", "sat"]
Step 1: Genera Q, K, V tramite proiezioni lineari
Q = X * W_Q K = X * W_K V = X * W_V
Q = [[1.0, 0.5, 0.3, 0.2], (The)
[0.8, 1.2, 0.1, 0.9], (cat)
[0.3, 0.4, 1.1, 0.6]] (sat)
K = [[0.9, 0.6, 0.4, 0.1],
[0.7, 1.1, 0.2, 0.8],
[0.4, 0.3, 1.0, 0.5]]
V = [[0.2, 0.8, 0.1, 0.5],
[0.9, 0.3, 0.7, 0.2],
[0.4, 0.6, 0.5, 0.8]]
Step 2: Calcola Q * K^T (matrice 3x3 di score)
Score[i][j] = dot(Q[i], K[j])
Scores = [[1.19, 1.37, 0.89],
[1.35, 1.77, 1.10],
[0.98, 1.15, 1.42]]
Step 3: Scala per sqrt(d_k) = sqrt(4) = 2
Scaled = [[0.60, 0.69, 0.45],
[0.68, 0.89, 0.55],
[0.49, 0.58, 0.71]]
Step 4: Applica softmax per riga
Weights = [[0.33, 0.36, 0.31], (The guarda The, cat, sat)
[0.32, 0.40, 0.28], (cat guarda The, cat, sat)
[0.29, 0.32, 0.39]] (sat guarda The, cat, sat)
Step 5: Moltiplica pesi per V
Output[0] = 0.33*V[0] + 0.36*V[1] + 0.31*V[2]
= [0.51, 0.56, 0.39, 0.48]
Attenzione alla Complessità
La matrice Q * K^T ha dimensione n x n, dove n e la lunghezza della sequenza. Con n = 1000, la matrice ha 1.000.000 di elementi. Con n = 100.000, ha 10 miliardi di elementi. Questa complessità quadratica O(n^2) e il principale collo di bottiglia dei Transformer e la ragione per cui varianti come Flash Attention e Sliding Window Attention sono state sviluppate.
4. Multi-Head Attention: Guardare da Più Angolazioni
Una singola operazione di attention cattura un tipo di relazione tra i token. Ma le relazioni in una sequenza sono molteplici: relazioni sintattiche (soggetto-verbo), semantiche (sinonimi, contesto), posizionali (token adiacenti) e molte altre. La Multi-Head Attention risolve questo problema eseguendo l'attention in parallelo con diverse proiezioni.
Input X (dimensione: n x d_model, es. n x 512)
|
+---> Head 1: Q1=X*Wq1, K1=X*Wk1, V1=X*Wv1 --> Attention(Q1,K1,V1) --> Z1
| (d_k = d_model/h = 64)
+---> Head 2: Q2=X*Wq2, K2=X*Wk2, V2=X*Wv2 --> Attention(Q2,K2,V2) --> Z2
|
+---> Head 3: Q3=X*Wq3, K3=X*Wk3, V3=X*Wv3 --> Attention(Q3,K3,V3) --> Z3
|
+---> ...
|
+---> Head 8: Q8=X*Wq8, K8=X*Wk8, V8=X*Wv8 --> Attention(Q8,K8,V8) --> Z8
|
v
Concatena: [Z1; Z2; Z3; ... Z8] (dimensione: n x d_model)
|
v
Proiezione finale: Concat * W_O (dimensione: n x d_model)
Con h = 8 teste e d_model = 512, ogni testa lavora su uno
spazio di dimensione d_k = d_v = 512 / 8 = 64. Il costo computazionale totale
e simile a quello di una singola attention con dimensione piena, perchè le teste operano
in parallelo su sottospazi più piccoli.
Cosa Impara Ogni Testa
Ricerche empiriche hanno dimostrato che le diverse teste si specializzano in pattern diversi:
- Testa 1: Potrebbe imparare relazioni soggetto-verbo
- Testa 2: Potrebbe imparare relazioni di coreference (pronomi e i loro antecedenti)
- Testa 3: Potrebbe concentrarsi su token adiacenti (n-grammi locali)
- Testa 4: Potrebbe catturare relazioni a lungo raggio tra frasi
- Altre teste: Pattern sintattici, entità, struttura del discorso
Formula Multi-Head Attention
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
dove head_i = Attention(Q * W_Qi, K * W_Ki, V * W_Vi)
Parametri tipici nel paper originale: d_model = 512, h = 8, d_k = d_v = 64. Nei modelli moderni: d_model = 4096-8192, h = 32-128.
5. Self-Attention: Un Token Che Guarda Tutti Gli Altri
La self-attention e il caso specifico in cui Query, Key e Value provengono tutti dalla stessa sequenza. Ogni token genera la propria Query, Key e Value, e usa la Query per "interrogare" le Key di tutti gli altri token (incluso se stesso).
Frase: "The cat sat on the mat"
Attention Matrix (ogni riga somma a 1.0):
The cat sat on the mat
The [0.15 0.25 0.10 0.05 0.15 0.30]
cat [0.10 0.20 0.35 0.05 0.05 0.25]
sat [0.05 0.40 0.15 0.20 0.05 0.15]
on [0.05 0.10 0.30 0.10 0.15 0.30]
the [0.20 0.15 0.05 0.10 0.10 0.40]
mat [0.10 0.15 0.15 0.25 0.15 0.20]
Osservazioni:
- "sat" presta molta attenzione a "cat" (0.40) --> soggetto-verbo
- "on" presta attenzione a "sat" (0.30) e "mat" (0.30) --> relazione spaziale
- "the" (seconda occorrenza) presta molta attenzione a "mat" (0.40) --> articolo-sostantivo
La self-attention e il cuore dei Transformer. E ciò che permette al modello di costruire rappresentazioni contestuali: la rappresentazione di ogni token incorpora informazioni da tutta la sequenza, pesate in base alla rilevanza. La parola "bank" avra una rappresentazione diversa in "river bank" e "bank account" perchè i token circostanti influenzano la sua rappresentazione tramite l'attention.
Masked Self-Attention nei Decoder
Nei modelli generativi (decoder), la self-attention e mascherata: ogni token può vedere solo i token precedenti, non quelli futuri. Questo si implementa impostando a -infinito gli score dei token futuri prima della softmax, producendo pesi pari a zero. Questa e la causal attention usata in GPT, Llama e tutti i modelli autoregressive.
Mask per sequenza di 5 token (0 = visibile, -inf = mascherato):
t1 t2 t3 t4 t5
t1 [ 0 -inf -inf -inf -inf ]
t2 [ 0 0 -inf -inf -inf ]
t3 [ 0 0 0 -inf -inf ]
t4 [ 0 0 0 0 -inf ]
t5 [ 0 0 0 0 0 ]
Dopo la softmax:
t1 vede solo [t1]
t2 vede solo [t1, t2]
t3 vede solo [t1, t2, t3]
...e cosi via
6. Cross-Attention: Quando Encoder e Decoder Comunicano
La cross-attention (o encoder-decoder attention) e il meccanismo che permette al decoder di "guardare" l'output dell'encoder. A differenza della self-attention, in cui Q, K e V provengono dalla stessa sequenza, nella cross-attention le Query vengono dal decoder e le Key/Value dall'encoder.
ENCODER (processa l'input, es. frase in italiano):
"Il gatto dorme" --> Encoder --> Rappresentazioni encoder (K_enc, V_enc)
DECODER (genera l'output, es. traduzione in inglese):
"The cat" --> Self-Attention mascherata --> Q_dec
CROSS-ATTENTION:
Q = Q_dec (dal decoder: "cosa sto cercando per generare il prossimo token?")
K = K_enc (dall'encoder: "cosa contiene ogni token dell'input?")
V = V_enc (dall'encoder: "ecco le informazioni dell'input")
Il decoder può "guardare" tutta la sequenza dell'encoder
per decidere quale token generare dopo.
La cross-attention e fondamentale nelle architetture encoder-decoder usate per la traduzione automatica (T5, mBART), il riassunto di testi e la generazione condizionata. In T5, per esempio, l'encoder processa il testo di input e il decoder genera il testo di output, usando la cross-attention per consultare l'encoder ad ogni passo di generazione.
I Tre Tipi di Attention nei Transformer
| Tipo | Sorgente Q | Sorgente K, V | Dove si Usa |
|---|---|---|---|
| Self-Attention (encoder) | Encoder input | Encoder input | Encoder di BERT, T5 encoder |
| Masked Self-Attention | Decoder input | Decoder input | GPT, Llama, decoder di T5 |
| Cross-Attention | Decoder | Encoder output | Decoder di T5, mBART |
7. Positional Encoding: Come i Transformer Conoscono l'Ordine
A differenza delle RNN, che elaborano i token in ordine sequenziale, il meccanismo di self-attention e invariante rispetto all'ordine: il risultato non cambia se permuti i token in input. "Il gatto mangia il pesce" e "pesce il gatto il mangia" produrrebbero la stessa output senza un meccanismo aggiuntivo. Il positional encoding risolve questo problema aggiungendo informazione sulla posizione di ogni token.
7.1 Sinusoidal Positional Encoding (Paper Originale)
Il paper originale usa funzioni sinusoidali per generare i positional encoding:
Formule del Positional Encoding Sinusoidale
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Dove pos e la posizione del token nella sequenza e i e la dimensione. Le posizioni pari usano il seno, le dispari il coseno. La diversa frequenza per ogni dimensione permette al modello di apprendere relazioni posizionali relative.
Posizione 0: [sin(0), cos(0), sin(0), cos(0), ...] = [0.00, 1.00, 0.00, 1.00, ...]
Posizione 1: [sin(1), cos(1), sin(0.01), cos(0.01)] = [0.84, 0.54, 0.01, 1.00, ...]
Posizione 2: [sin(2), cos(2), sin(0.02), cos(0.02)] = [0.91, -0.42, 0.02, 1.00, ...]
L'embedding finale di ogni token e:
token_embedding = word_embedding + positional_encoding
Le frequenze più basse (dimensioni alte) catturano posizioni globali.
Le frequenze più alte (dimensioni basse) catturano posizioni locali.
7.2 Learned Positional Encoding
Un'alternativa al positional encoding sinusoidale e l'uso di embedding appresi (learned): una matrice di parametri addestrabili, una riga per ogni posizione. Questo approccio e usato in BERT e GPT-2. Il vantaggio e che il modello può apprendere pattern posizionali ottimali per il task specifico. Lo svantaggio e che la lunghezza massima della sequenza e fissata al training.
Confronto Positional Encoding
| Tipo | Vantaggi | Svantaggi | Usato in |
|---|---|---|---|
| Sinusoidale | Nessun parametro extra, generalizza a sequenze più lunghe | Pattern fissi, non ottimizzati per il task | Transformer originale |
| Learned | Ottimizzato per il task specifico | Lunghezza massima fissa, più parametri | BERT, GPT-2 |
| RoPE (Rotary) | Cattura posizioni relative, estendibile | Maggiore complessità implementativa | Llama, Mistral, GPT-NeoX |
| ALiBi | Nessun parametro, buona extrapolation | Bias lineare può essere limitante | BLOOM, MPT |
8. L'Architettura Transformer Completa
Con tutti i pezzi del puzzle in mano, possiamo ora assemblare l'architettura Transformer completa. Il Transformer originale e composto da un encoder stack e un decoder stack, ciascuno formato da N layer identici (N = 6 nel paper originale).
INPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| ENCODER STACK | x N (6 nel paper originale)
| |
| +--Multi-Head-------+
| | Self-Attention |
| +------|------------+
| v
| +--Add & Norm-------+ (residual connection + layer norm)
| +------|------------+
| v
| +--Feed-Forward-----+ (2 layer lineari con ReLU/GELU)
| | Network | (d_model -> d_ff -> d_model)
| +------|------------+ (d_ff = 4 * d_model = 2048)
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
| (K, V per cross-attention)
|
OUTPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| DECODER STACK | x N
| |
| +--Masked Multi-----+
| | Head Self-Attn | (causal mask: vede solo il passato)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Cross-Attention--+ (Q dal decoder, K/V dall'encoder)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Feed-Forward-----+
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
v
Linear + Softmax
|
v
Output Probabilities (vocabulario)
8.1 Residual Connections
Ogni sub-layer (attention o feed-forward) ha una residual connection:
l'output del sub-layer viene sommato all'input. La formula e
output = LayerNorm(x + SubLayer(x)). Le residual connection risolvono il
problema del vanishing gradient nelle reti profonde, permettendo ai gradienti di fluire
direttamente attraverso le connessioni di scorciatoia.
8.2 Feed-Forward Network
Dopo l'attention, ogni token passa attraverso una rete feed-forward applicata indipendentemente a ogni posizione. E composta da due trasformazioni lineari con un'attivazione non lineare (ReLU nel paper originale, GELU o SwiGLU nei modelli moderni):
FFN(x) = W2 * activation(W1 * x + b1) + b2
La dimensione interna (d_ff) e tipicamente 4 volte d_model. Con d_model = 512, d_ff = 2048. Nei modelli moderni come Llama 3, d_ff arriva a 14.336 con d_model = 4096.
8.3 Layer Normalization
La Layer Normalization normalizza le attivazioni lungo la dimensione delle feature (non del batch). Stabilizza il training e accelera la convergenza. Nel Transformer originale si usa Post-LN (normalizzazione dopo la residual connection), ma la maggior parte dei modelli moderni usa Pre-LN (normalizzazione prima del sub-layer), che e più stabile durante il training.
9. Implementazione PyTorch: Self-Attention da Zero
Passiamo dalla teoria al codice. Implementeremo la Scaled Dot-Product Attention e la Multi-Head Attention da zero in PyTorch, senza usare moduli pre-costruiti.
9.1 Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
dropout: nn.Dropout = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Scaled Dot-Product Attention.
Args:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: (batch, 1, 1, seq_len) o (batch, 1, seq_len, seq_len)
dropout: modulo dropout opzionale
Returns:
output: (batch, heads, seq_len, d_v)
attention_weights: (batch, heads, seq_len, seq_len)
"""
d_k = query.size(-1)
# Step 1: Calcola gli score Q * K^T / sqrt(d_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: Applica la maschera (opzionale)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax per ottenere i pesi di attention
attention_weights = F.softmax(scores, dim=-1)
# Step 4: Dropout opzionale sui pesi
if dropout is not None:
attention_weights = dropout(attention_weights)
# Step 5: Moltiplica pesi per Value
output = torch.matmul(attention_weights, value)
return output, attention_weights
9.2 Multi-Head Attention
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention implementata da zero.
Parametri:
d_model: dimensione del modello (es. 512)
num_heads: numero di teste di attention (es. 8)
dropout: tasso di dropout (es. 0.1)
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, \
f"d_model ({d_model}) deve essere divisibile per num_heads ({num_heads})"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # dimensione per testa
# Proiezioni lineari per Q, K, V e output
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Riorganizza il tensore da (batch, seq_len, d_model)
a (batch, num_heads, seq_len, d_k).
"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, heads, seq_len, d_k)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
Forward pass.
Per Self-Attention: query = key = value = X
Per Cross-Attention: query = decoder, key = value = encoder
"""
batch_size = query.size(0)
# 1. Proiezioni lineari
q = self.w_q(query) # (batch, seq_len, d_model)
k = self.w_k(key)
v = self.w_v(value)
# 2. Dividi in teste
q = self.split_heads(q) # (batch, heads, seq_len, d_k)
k = self.split_heads(k)
v = self.split_heads(v)
# 3. Scaled Dot-Product Attention
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v, mask=mask, dropout=self.dropout
)
# 4. Concatena le teste
# (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 5. Proiezione finale
output = self.w_o(attn_output)
return output
9.3 Esempio d'Uso
# Configurazione
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
# Crea il modulo
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
# Input random (simula una sequenza di token embeddings)
x = torch.randn(batch_size, seq_len, d_model)
# Self-Attention (query = key = value)
output = mha(query=x, key=x, value=x)
print(f"Input shape: {x.shape}") # torch.Size([2, 10, 512])
print(f"Output shape: {output.shape}") # torch.Size([2, 10, 512])
# Causal mask per decoder (triangolare inferiore)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# Masked Self-Attention
output_masked = mha(query=x, key=x, value=x, mask=causal_mask)
print(f"Masked output shape: {output_masked.shape}")
# Cross-Attention (query dal decoder, key/value dall'encoder)
encoder_output = torch.randn(batch_size, 20, d_model) # sequenza encoder più lunga
decoder_input = torch.randn(batch_size, seq_len, d_model)
cross_attn_output = mha(
query=decoder_input,
key=encoder_output,
value=encoder_output
)
print(f"Cross-attention shape: {cross_attn_output.shape}") # [2, 10, 512]
10. Varianti Moderne di Attention
La complessità quadratica O(n^2) della standard attention ha motivato lo sviluppo di numerose varianti ottimizzate. Queste varianti sono fondamentali per i modelli moderni che gestiscono contesti da 100K a oltre 1 milione di token.
10.1 Flash Attention (v1, v2, v3)
Flash Attention, sviluppata da Tri Dao e colleghi, non cambia la matematica dell'attention ma ne ottimizza radicalmente l'implementazione a livello hardware. L'idea chiave e evitare di materializzare la matrice completa n x n degli attention score in memoria GPU (HBM), usando invece un approccio tiled che lavora interamente nella SRAM (memoria veloce on-chip).
Evoluzione di Flash Attention
| Versione | Anno | Innovazione Chiave | Performance |
|---|---|---|---|
| Flash Attention 1 | 2022 | Tiling + fused kernel, IO-awareness | 2-4x speedup vs standard |
| Flash Attention 2 | 2023 | Parallelismo migliorato, meno comunicazione | 2x ulteriore rispetto a v1 |
| Flash Attention 3 | 2024 | Asincronia su Hopper GPU, FP8, warp specialization | Fino a 740 TFLOPS (FP16) su H100, 1.2 PFLOPS con FP8 |
Flash Attention 3 sfrutta le caratteristiche specifiche delle GPU NVIDIA Hopper (H100/H200): asincronia tra Tensor Core e TMA (Tensor Memory Accelerator) per sovrapporre calcolo e trasferimento dati, warp specialization per interleaving ottimale delle operazioni matmul e softmax, e quantizzazione a blocchi FP8 con errore numerico 2.6 volte inferiore rispetto a un'implementazione FP8 naive. Flash Attention e ora integrata in PyTorch, Hugging Face Transformers, vLLM e TensorRT-LLM.
10.2 Multi-Query Attention (MQA)
Proposta da Shazeer nel 2019, la Multi-Query Attention riduce drasticamente la memoria necessaria per la KV cache durante l'inference. Invece di avere un set separato di Key e Value per ogni testa, MQA condivide un singolo set di K e V tra tutte le teste, mantenendo Query diverse.
Multi-Head Attention (MHA) - Standard:
Head 1: Q1, K1, V1 | KV Cache per head: d_k * seq_len * 2
Head 2: Q2, K2, V2 | KV Cache totale: h * d_k * seq_len * 2
... | Con h=32, d_k=128, seq=4096:
Head h: Qh, Kh, Vh | = 32 * 128 * 4096 * 2 = 33.5 MB per layer
Multi-Query Attention (MQA):
Head 1: Q1 \
Head 2: Q2 |--- K_shared, V_shared
... | KV Cache totale: d_k * seq_len * 2
Head h: Qh / = 128 * 4096 * 2 = 1.05 MB per layer (32x meno!)
10.3 Grouped-Query Attention (GQA)
GQA, introdotta da Ainslie et al. nel 2023, e un compromesso tra MHA e MQA. Invece di condividere un singolo set di K/V tra tutte le teste (MQA) o averne uno per ogni testa (MHA), GQA raggruppa le teste in g gruppi, con ogni gruppo che condivide un set di K/V. Con g = 1 si ottiene MQA, con g = h si ottiene MHA.
Esempio: 8 query heads, 2 KV groups (g=2)
Gruppo 1: Q1, Q2, Q3, Q4 condividono K1, V1
Gruppo 2: Q5, Q6, Q7, Q8 condividono K2, V2
KV Cache: g * d_k * seq_len * 2 = 2 * 128 * 4096 * 2 = 2.1 MB
(16x meno di MHA, ma solo 2x più di MQA)
Modelli che usano GQA:
- Llama 2 (70B): 8 KV heads, 64 query heads
- Llama 3: GQA con rapporto 8:1
- Mistral 7B: 8 KV heads, 32 query heads
Confronto Varianti di Attention
| Variante | KV Heads | Memoria KV Cache | qualità | Modelli |
|---|---|---|---|---|
| MHA | h (tutte) | Massima | Migliore | BERT, GPT-2, GPT-3 |
| GQA | g (gruppi) | h/g riduzione | Quasi pari a MHA | Llama 2/3, Mistral |
| MQA | 1 | Minima | Lieve calo | PaLM, Falcon |
10.4 Sliding Window Attention
La Sliding Window Attention, usata in Mistral e Longformer, limita l'attention a una finestra locale di w token per ogni posizione. Invece di calcolare l'attention su tutta la sequenza (O(n^2)), ogni token vede solo i w token precedenti, riducendo la complessità a O(n * w).
Sequenza: t1 t2 t3 t4 t5 t6 t7 t8
Attention di t5 (window=3): vede solo [t3, t4, t5]
Attention di t8 (window=3): vede solo [t6, t7, t8]
Attention Matrix (1 = visibile, 0 = mascherato):
t1 t2 t3 t4 t5 t6 t7 t8
t1 [ 1 0 0 0 0 0 0 0 ]
t2 [ 1 1 0 0 0 0 0 0 ]
t3 [ 1 1 1 0 0 0 0 0 ]
t4 [ 0 1 1 1 0 0 0 0 ]
t5 [ 0 0 1 1 1 0 0 0 ]
t6 [ 0 0 0 1 1 1 0 0 ]
t7 [ 0 0 0 0 1 1 1 0 ]
t8 [ 0 0 0 0 0 1 1 1 ]
L'informazione NON si perde: attraverso più layer stacked,
l'informazione di t1 può raggiungere t8 per propagazione.
Con L layer e window w, la reception field effettiva e L * w.
10.5 Ring Attention e PagedAttention
Per contesti lunghissimi (oltre 1 milione di token), sono emerse ulteriori innovazioni:
- Ring Attention: distribuisce il calcolo dell'attention su più GPU organizzate ad anello. Ogni GPU calcola l'attention su un segmento della sequenza e passa i risultati alla GPU successiva. RingX (2025) raggiunge efficienza del 94% fino a 4096 GPU con sequenze da 1 milione di token.
- PagedAttention: ispirata alla gestione della memoria virtuale dei sistemi operativi, alloca la KV cache in blocchi (pagine) non contigue, eliminando la frammentazione di memoria. E alla base di vLLM e permette batch size fino a 76 volte superiori.
- FlexAttention (PyTorch): un'API unificata che supporta diverse varianti di attention (GQA, causal, sliding window, PagedAttention) con meno del 5% di overhead rispetto a implementazioni dedicate.
11. Applicazioni: Le Architetture Transformer nella Pratica
L'architettura Transformer ha dato vita a tre famiglie principali di modelli, ciascuna che utilizza l'attention in modo diverso.
11.1 Encoder-Only: BERT e Derivati
I modelli encoder-only usano self-attention bidirezionale: ogni token può vedere tutti gli altri token nella sequenza, sia quelli precedenti che quelli successivi. Questo li rende ideali per task di comprensione del linguaggio.
BERT (Bidirectional Encoder Representations from Transformers)
- Pre-training: Masked Language Model (MLM) + Next Sentence Prediction
- Attention: Self-attention bidirezionale (vede tutta la sequenza)
- Task: Classificazione, Named Entity Recognition, Question Answering
- Varianti: RoBERTa, ALBERT, DeBERTa, DistilBERT
11.2 Decoder-Only: GPT e la Famiglia di LLM
I modelli decoder-only usano masked self-attention (causal): ogni token vede solo i token precedenti. Sono ottimizzati per la generazione autoregressiva di testo.
Modelli Decoder-Only
| Modello | Parametri | Variante Attention | Context Window |
|---|---|---|---|
| GPT-3 | 175B | MHA standard | 2K-4K token |
| GPT-4 | ~1.8T (MoE) | GQA (stimato) | 128K token |
| Llama 3 405B | 405B | GQA + RoPE | 128K token |
| Mistral 7B | 7.3B | GQA + Sliding Window | 32K token |
| Claude (Anthropic) | Non pubblicato | Non pubblicato | 200K token |
11.3 Encoder-Decoder: T5 e Modelli Seq2Seq
I modelli encoder-decoder usano tutti e tre i tipi di attention: self-attention bidirezionale nell'encoder, masked self-attention nel decoder e cross-attention tra decoder e encoder. Sono ideali per task che trasformano un input in un output (traduzione, riassunto, question answering).
Modelli Encoder-Decoder
- T5: "Text-to-Text Transfer Transformer" - ogni task e formulato come testo-in-testo-out
- BART: Denoising autoencoder per generazione e comprensione
- mBART: BART multilingue per traduzione
- Flan-T5: T5 istruito con instruction tuning
11.4 Vision Transformer (ViT)
L'attention non si limita al testo. I Vision Transformer applicano la self-attention alle immagini, dividendo l'immagine in patch (es. 16x16 pixel) e trattando ogni patch come un "token". Questo ha dimostrato che l'attention e un meccanismo generale applicabile a qualsiasi tipo di dato sequenziale.
Immagine 224x224 pixel
|
v
Dividi in patch 16x16: (224/16)^2 = 196 patch
|
v
Ogni patch -> flatten -> proiezione lineare -> patch embedding
|
v
[CLS] + 196 patch embeddings + positional encoding
|
v
Transformer Encoder (self-attention su 197 token)
|
v
[CLS] token -> classificazione dell'immagine
Conclusioni e Prossimi Passi
In questo articolo abbiamo percorso l'intero arco dell'attention mechanism: dal problema delle dipendenze a lungo termine nelle RNN, all'intuizione di Query-Key-Value, alla formula della Scaled Dot-Product Attention, alla Multi-Head Attention, fino all'architettura Transformer completa. Abbiamo implementato la self-attention da zero in PyTorch e esplorato le varianti moderne che rendono possibili i modelli con milioni di token di contesto.
L'attention e il mattone fondamentale su cui si costruisce tutto il deep learning moderno. Comprendere come funziona ti permette di capire perchè alcune ottimizzazioni funzionano, perchè certi modelli sono più veloci di altri e come scegliere l'architettura giusta per il tuo caso d'uso.
Concetti Chiave da Ricordare
- Attention permette connessioni dirette tra qualsiasi coppia di token, senza colli di bottiglia
- Scaling (sqrt(d_k)) previene gradienti instabili nella softmax
- Multi-Head cattura relazioni diverse in parallelo senza costo aggiuntivo
- Self-Attention crea rappresentazioni contestuali; Cross-Attention collega encoder e decoder
- Positional Encoding fornisce informazione sull'ordine (sinusoidale, learned, RoPE)
- Flash Attention ottimizza l'implementazione hardware senza cambiare la matematica
- GQA e il compromesso ottimale tra qualità (MHA) e efficienza (MQA)
Nel prossimo articolo della serie, esploreremo il fine-tuning dei Transformer con LoRA, QLoRA e Adapters: come adattare modelli pre-addestrati a task specifici modificando solo una piccola frazione dei parametri, riducendo drasticamente i costi di GPU e memoria.
Risorse Aggiuntive
- Paper originale: "Attention Is All You Need" (Vaswani et al., 2017)
- Flash Attention 3: "Fast and Accurate Attention with Asynchrony and Low-precision" (Dao et al., 2024)
- GQA Paper: "GQA: Training Generalized Multi-Query Transformer Models" (Ainslie et al., 2023)
- The Illustrated Transformer: Guida visuale di Jay Alammar
- PyTorch Documentation: torch.nn.MultiheadAttention per implementazioni ottimizzate
- Hugging Face: Documentazione dei Transformer con esempi pratici







