Vision Transformer (ViT): Arhitectură și aplicații practice
În 2020, o lucrare Google Research a schimbat radical viziunea computerizată: „An Image is Worth Cuvinte 16x16". Intuiția a fost simplă, dar revoluționară - aplicați arhitectura Transformer, dominantă în NLP, direct la imagini prin tratarea i petice vizuale ca semne. Rezultatul a fost Transformator de vedere (ViT), care a depășit CNN în câțiva ani de ultimă generație pe ImageNet și zeci de alte repere, deschizând calea pentru un nou generarea de modele vizuale.
Promisiunea ViT nu este doar acuratețe: este versatilitate. Aceeași coloană vertebrală Transformerul folosit pentru text poate fi partajat cu imagini, permițând șabloane multimoduri precum CLIP, DALL-E și GPT-4V. ViT-urile se scalează mai bine decât CNN-urile cu date și calcul suplimentare, și variante precum Swin Transformer e DeiT au făcut aceste modele eficiente chiar și pe seturi de date de dimensiuni medii fără antrenament prealabil pe sute de milioane de imagini.
În acest ghid, construim un ViT de la zero în PyTorch, explorăm variațiile arhitecturale cel mai important și vă arătăm cum să ajustați pentru sarcini specifice de producție.
Ce vei învăța
- Arhitectura ViT: încorporare de patch-uri, codificare pozițională, atenție vizuală
- Implementare completă de la zero cu PyTorch
- Diferențele dintre ViT-B/16, ViT-L/32, DeiT, Swin Transformer
- Reglarea fină a ViT pre-instruit pe seturi de date personalizate
- Tehnici de creștere a datelor pentru ViT (MixUp, CutMix, RandAugment)
- Desfășurarea atenției și interpretabilitatea hărților de atenție
- Implementare optimizată: ONNX, TorchScript, dispozitive edge
- Benchmark ViT vs CNN pe seturi de date reale
Arhitectura ViT: Cum funcționează
Vision Transformer preia o imagine ca intrare și o împarte în patch-uri care nu se suprapun
dimensiune fixă (de obicei 16x16 sau 32x32 pixeli). Fiecare plasture vine turtit
(aplatiza) și proiectat liniar într-un vector de dimensiune d_model (înglobarea).
Aceste înglobări, numite înglobări de plasturi, devin jetoanele Transformer.
Un simbol special [CLS] (jeton de clasă) este prefixat la secvență, în mod similar
la BERT în NLP. Odată ce codificarea este completă, reprezentarea jetonului CLS este transmisă
un cap de clasificare pentru a produce predicția finală. Codificare pozițională - în formă
învățat în loc de sine — se adaugă înglobărilor pentru a păstra informațiile
spațiu care s-ar pierde fără el.
# Diagramma architettura ViT
#
# Input Image (224x224x3)
# |
# v
# Patch Extraction: divide in 196 patch di 16x16
# (224/16 = 14 patch per lato -> 14*14 = 196 patch)
# |
# v
# Patch Embedding: ogni patch [768] via Linear projection
# + [CLS] token -> sequenza di 197 token
# |
# v
# + Positional Embedding (learnable, 197x768)
# |
# v
# Transformer Encoder (L strati):
# - LayerNorm
# - Multi-Head Self-Attention (h heads)
# - Residual connection
# - LayerNorm
# - MLP (d_model -> 4*d_model -> d_model)
# - Residual connection
# |
# v
# [CLS] token representation
# |
# v
# MLP Head -> num_classes output
# Varianti standard:
# ViT-B/16: d_model=768, L=12, h=12 | ~86M param
# ViT-L/16: d_model=1024, L=24, h=16 | ~307M param
# ViT-H/14: d_model=1280, L=32, h=16 | ~632M param
Implementarea ViT de la zero
Să construim un ViT complet în PyTorch. Să începem de la componenta fundamentală: the Încorporarea patch-urilor.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
# ============================================================
# 1. PATCH EMBEDDING
# ============================================================
class PatchEmbedding(nn.Module):
"""
Converte un'immagine in una sequenza di patch embedding.
Metodo 1: Convolution (efficiente, equivalente a patch+linear)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Equivalente a: flatten ogni patch + proiezione lineare
# Ma implementato come Conv2d per efficienza
self.projection = nn.Sequential(
# Divide in patch e proietta
nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size),
# [B, d_model, H/P, W/P] -> [B, n_patches, d_model]
Rearrange('b d h w -> b (h w) d')
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, C, H, W]
return self.projection(x) # [B, n_patches, d_model]
# ============================================================
# 2. MULTI-HEAD SELF ATTENTION per ViT
# ============================================================
class ViTAttention(nn.Module):
"""Multi-head self-attention con dropout."""
def __init__(self, d_model=768, n_heads=12, attn_dropout=0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim ** -0.5
# QKV projection
self.qkv = nn.Linear(d_model, d_model * 3, bias=True)
self.attn_drop = nn.Dropout(attn_dropout)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
# Calcola Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Ognuno: [B, heads, N, head_dim]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
attn_weights = attn # Salva per attention rollout
attn = self.attn_drop(attn)
# Weighted sum + proiezione output
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, attn_weights
# ============================================================
# 3. TRANSFORMER ENCODER BLOCK
# ============================================================
class ViTBlock(nn.Module):
"""Singolo blocco Transformer per ViT."""
def __init__(self, d_model=768, n_heads=12, mlp_ratio=4.0,
dropout=0.0, attn_dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = ViTAttention(d_model, n_heads, attn_dropout)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor):
# Pre-norm + residual connection
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# ============================================================
# 4. VISION TRANSFORMER COMPLETO
# ============================================================
class VisionTransformer(nn.Module):
"""
Vision Transformer (ViT) completo.
Paper: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
d_model: int = 768,
depth: int = 12,
n_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
attn_dropout: float = 0.0,
representation_size: int = None # Pre-logit layer (opzionale)
):
super().__init__()
self.num_classes = num_classes
self.d_model = d_model
# Patch + Position Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
n_patches = self.patch_embed.n_patches
# Token CLS e positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
ViTBlock(d_model, n_heads, mlp_ratio, dropout, attn_dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(d_model)
# Classification head
if representation_size is not None:
self.pre_logits = nn.Sequential(
nn.Linear(d_model, representation_size),
nn.Tanh()
)
else:
self.pre_logits = nn.Identity()
self.head = nn.Linear(
representation_size if representation_size else d_model,
num_classes
)
# Inizializzazione pesi
self._init_weights()
def _init_weights(self):
"""Inizializzazione standard per ViT."""
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor, return_attn: bool = False):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # [B, n_patches, d_model]
# 2. Prepend CLS token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
x = torch.cat([cls_tokens, x], dim=1) # [B, n_patches+1, d_model]
# 3. Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# 4. Transformer blocks
attn_weights_list = []
for block in self.blocks:
x, attn_weights = block(x)
attn_weights_list.append(attn_weights)
# 5. Layer norm finale
x = self.norm(x)
# 6. Usa solo il CLS token per classificazione
cls_output = x[:, 0]
cls_output = self.pre_logits(cls_output)
logits = self.head(cls_output)
if return_attn:
return logits, attn_weights_list
return logits
# ============================================================
# 5. CREAZIONE VARIANTI STANDARD
# ============================================================
def vit_base_16(num_classes=1000, **kwargs):
"""ViT-B/16: 86M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=768,
depth=12, n_heads=12, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_large_16(num_classes=1000, **kwargs):
"""ViT-L/16: 307M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=1024,
depth=24, n_heads=16, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_tiny_16(num_classes=1000, **kwargs):
"""ViT-Ti/16: ~6M parametri, per edge/mobile."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=192,
depth=12, n_heads=3, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
# Test
model = vit_base_16(num_classes=100)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}") # [2, 100]
print(f"Parametri: {sum(p.numel() for p in model.parameters()):,}")
# Parametri: 85,880,164
Variante arhitecturale: DeiT, Swin și BEiT
ViT original a necesitat cantități uriașe de date (JFT-300M, 300 de milioane de imagini) pentru depășesc CNN-urile. Această limitare a împins dezvoltarea unor variante mai eficiente din punct de vedere al datelor:
| Model | An | Inovație cheie | ImageNet Top-1 | Parametrii |
|---|---|---|---|---|
| ViT-B/16 | 2020 | Primul ViT, necesită JFT-300M | 81,8% | 86M |
| DeiT-B | 2021 | Distilare de la profesorul CNN, doar ImageNet | 83,1% | 87M |
| Swin-B | 2021 | Fereastra deplasată Atenție, ierarhic | 85,2% | 88M |
| BEiT-L | 2022 | Modelarea imaginilor mascate (BERT pentru viziune) | 87,4% | 307M |
| DeiT III-H | 2022 | Rețetă de antrenament avansat | 87,7% | 632M |
| ViT-G (EVA) | 2023 | Scalează la parametrul 1B, pre-antrenament CLIP | 89,6% | 1.0B |
DeiT (transformatoare de imagine eficiente pentru date) de Facebook AI și probabil varianta cel mai important pentru practică: introduce jeton de distilare care vă permite să învățați de la un profesor CNN (cum ar fi RegNet sau ConvNext), obținând performanțe excelente doar cu ImageNet-1K.
Swin Transformer rezolvă problema complexității pătratice a atenției introducerea Schimbat Windows: atenția este calculată în ferestrele locale mai degrabă decât pe întreaga imagine, cu un cost de calcul liniar în raport cu imaginea. Swin produce reprezentări ierarhice (cum ar fi CNN) și este coloana vertebrală preferată pentru detectare și segmentare.
Reglaj fin ViT Pre-antrenat
Cel mai practic mod de a utiliza ViT-urile în producție este să porniți de la un model pre-antrenat ImageNet-21K și faceți reglaj fin pe setul dvs. de date. Hugging Face Transformers oferă toate Modele ViT de bază cu API uniformă.
# pip install transformers timm torch torchvision datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
ViTForImageClassification, ViTImageProcessor,
AutoImageProcessor
)
from torchvision import datasets, transforms
import os
# ============================================================
# FINE-TUNING ViT-B/16 su Dataset Custom
# ============================================================
class ViTFineTuner(nn.Module):
"""
ViT pre-addestrato con classification head custom.
Supporta fine-tuning parziale o completo.
"""
def __init__(self, num_classes: int, model_name: str = "google/vit-base-patch16-224",
freeze_backbone: bool = False):
super().__init__()
# Carica ViT pre-addestrato da HuggingFace
self.vit = ViTForImageClassification.from_pretrained(
model_name,
num_labels=num_classes,
ignore_mismatched_sizes=True # Permette cambio num_classes
)
if freeze_backbone:
# Congela tutto tranne il classification head
for param in self.vit.vit.parameters():
param.requires_grad = False
# Solo il classifier rimane trainable
print(f"Parametri trainabili: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
def forward(self, x):
outputs = self.vit(pixel_values=x)
return outputs.logits
# ============================================================
# DATA AUGMENTATION per ViT
# ============================================================
def get_vit_transforms(img_size: int = 224, mode: str = "train"):
"""
Augmentation pipeline ottimizzata per ViT.
ViT beneficia molto da augmentation aggressiva.
"""
if mode == "train":
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(num_ops=2, magnitude=9), # RandAugment
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.25) # CutOut/Erasing
])
else:
# Resize + center crop per validation/test
return transforms.Compose([
transforms.Resize(int(img_size * 1.143)), # 256 per 224
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# ============================================================
# TRAINING LOOP CON WARMUP + COSINE DECAY
# ============================================================
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
"""LR schedule: linear warmup + cosine decay (standard per ViT)."""
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
def train_vit(
model, train_loader, val_loader,
num_epochs=30, base_lr=3e-5, weight_decay=0.05,
device="cuda", label_smoothing=0.1
):
model = model.to(device)
# AdamW con weight decay (standard per ViT)
# Escludi bias e LayerNorm dal weight decay
no_decay_params = []
decay_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'bias' in name or 'norm' in name or 'cls_token' in name or 'pos_embed' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr)
total_steps = len(train_loader) * num_epochs
warmup_steps = len(train_loader) * 5 # 5 epoch di warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Label smoothing loss
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
best_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, labels)
loss.backward()
# Gradient clipping (importante per ViT)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_loss = train_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Val Acc: {val_acc:.4f} | "
f"LR: {current_lr:.2e}")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_vit.pth")
print(f" -> Nuovo best: {best_acc:.4f}")
return best_acc
MixUp și CutMix: Augmentare avansată pentru ViT
ViTs beneficiază în special de tehnici creșterea amestecurilor. MixUp combină liniar perechi de imagini și etichetele acestora; CutMix înlocuiește o porție porțiune dreptunghiulară a unei imagini cu porțiunea corespunzătoare a alteia. Ambele tehnici îmbunătățirea generalizării și calibrării modelului.
import numpy as np
class MixUpCutMix:
"""
Combinazione di MixUp e CutMix come in DeiT e timm.
Applica randomicamente uno dei due metodi per ogni batch.
"""
def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0,
prob=0.5, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
self.num_classes = num_classes
def one_hot(self, labels: torch.Tensor) -> torch.Tensor:
return F.one_hot(labels, self.num_classes).float()
def mixup(self, imgs, labels_oh):
"""MixUp: interpolazione lineare."""
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
B = imgs.size(0)
idx = torch.randperm(B)
mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
mixed_labels = lam * labels_oh + (1 - lam) * labels_oh[idx]
return mixed_imgs, mixed_labels
def cutmix(self, imgs, labels_oh):
"""CutMix: ritaglia e incolla patch."""
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
B, C, H, W = imgs.shape
idx = torch.randperm(B)
# Calcola dimensioni bounding box
cut_ratio = math.sqrt(1.0 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
# Centro casuale
cy = np.random.randint(H)
cx = np.random.randint(W)
y1 = max(0, cy - cut_h // 2)
y2 = min(H, cy + cut_h // 2)
x1 = max(0, cx - cut_w // 2)
x2 = min(W, cx + cut_w // 2)
# Applica CutMix
mixed_imgs = imgs.clone()
mixed_imgs[:, :, y1:y2, x1:x2] = imgs[idx, :, y1:y2, x1:x2]
# Ricalcola lambda effettivo
lam_actual = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
mixed_labels = lam_actual * labels_oh + (1 - lam_actual) * labels_oh[idx]
return mixed_imgs, mixed_labels
def __call__(self, imgs, labels):
labels_oh = self.one_hot(labels).to(imgs.device)
if np.random.random() < self.prob:
if np.random.random() < 0.5:
return self.mixup(imgs, labels_oh)
else:
return self.cutmix(imgs, labels_oh)
return imgs, labels_oh
# Uso nel training loop
mixup_cutmix = MixUpCutMix(num_classes=100)
# Nel training loop:
# imgs, soft_labels = mixup_cutmix(imgs, labels)
# loss = F.cross_entropy(logits, soft_labels) # Soft labels
Atenție: Vizualizați ceea ce vede ViT
Una dintre cele mai interesante caracteristici ale ViT-urilor este capacitatea de a vizualiza hărți de atenție pentru a înțelege ce regiuni ale imaginii ia în considerare modelul relevante. Tehnica Atenție lansare propaga atenția peste toate straturile pentru a obține o hartă de relevanță globală.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def compute_attention_rollout(attn_weights_list: list,
discard_ratio: float = 0.9) -> np.ndarray:
"""
Attention Rollout (Abnar & Zuidema, 2020).
Propaga le attention maps attraverso tutti i layer.
attn_weights_list: lista di tensori [B, heads, N, N]
discard_ratio: percentuale di attention da azzerare (focus sui top)
"""
n_layers = len(attn_weights_list)
# Media su tutte le teste
rollout = None
for attn in attn_weights_list:
# attn: [B, heads, N, N] -> media teste -> [B, N, N]
attn_mean = attn.mean(dim=1) # [B, N, N]
# Aggiunge identità (residual connection)
eye = torch.eye(attn_mean.size(-1), device=attn_mean.device)
attn_mean = attn_mean + eye
attn_mean = attn_mean / attn_mean.sum(dim=-1, keepdim=True)
if rollout is None:
rollout = attn_mean
else:
rollout = torch.bmm(attn_mean, rollout)
return rollout
def visualize_vit_attention(model, image_tensor: torch.Tensor,
patch_size: int = 16):
"""
Visualizza l'attention del CLS token sull'immagine.
"""
model.eval()
with torch.no_grad():
_, attn_list = model(image_tensor.unsqueeze(0), return_attn=True)
# Calcola rollout
rollout = compute_attention_rollout(attn_list) # [1, N+1, N+1]
# Attenzione del CLS verso tutti i patch
cls_attn = rollout[0, 0, 1:] # Escludi il CLS token stesso
# Ridimensiona alla griglia dei patch
H = W = int(math.sqrt(cls_attn.size(0)))
attn_map = cls_attn.reshape(H, W).cpu().numpy()
# Normalizza e upscale alla dimensione immagine
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
attn_map_upscaled = np.kron(attn_map, np.ones((patch_size, patch_size)))
return attn_map_upscaled
# Esempio di utilizzo e visualizzazione
# model = vit_base_16(num_classes=1000)
# img_tensor = get_vit_transforms(mode="val")(Image.open("dog.jpg"))
# attn_map = visualize_vit_attention(model, img_tensor)
#
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(img_tensor.permute(1,2,0).numpy())
# plt.title("Immagine originale")
# plt.subplot(1, 2, 2)
# plt.imshow(attn_map, cmap='inferno')
# plt.title("Attention Rollout (CLS token)")
# plt.colorbar()
# plt.tight_layout()
# plt.savefig("vit_attention.png", dpi=150)
Swin Transformer: Atenție la ferestrele ierarhice
Il Swin Transformer abordează două limitări fundamentale ale ViT standard: complexitatea pătratică a atenţiei (care limitează rezoluţia procesabilă) şi absenţa a reprezentărilor ierarhice (necesare pentru detecţie şi segmentare).
Swin împarte imaginea în ferestre care nu se suprapun și calculează atenția doar în interior a fiecărei ferestre (complexitate liniară). Între un strat și altul vin ferestre schimbare pentru a permite comunicarea între ferestrele adiacente. Structura ierarhică reduce progresiv rezoluția spațială, producând hărți de caracteristici la scară 4, cum ar fi CNN-urile tradiționale.
# Uso di Swin Transformer tramite timm (più semplice che implementare da zero)
# pip install timm
import timm
import torch
# Crea Swin-T (Tiny): 28M param, 81.3% ImageNet Top-1
swin_tiny = timm.create_model(
'swin_tiny_patch4_window7_224',
pretrained=True,
num_classes=0 # 0 = rimuovi classifier (backbone solo)
)
# Swin-B (Base): 88M param, 85.2% ImageNet Top-1
swin_base = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
num_classes=100 # Custom classifier
)
# Swin-V2-L per alta risoluzione (resolution scaling)
swin_v2 = timm.create_model(
'swinv2_large_window12to16_192to256_22kft1k',
pretrained=True,
num_classes=10
)
# Verifica feature maps gerarchiche (per detection/segmentation)
swin_backbone = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
features_only=True, # Restituisce feature a 4 scale
out_indices=(0, 1, 2, 3)
)
x = torch.randn(2, 3, 224, 224)
features = swin_backbone(x)
for i, feat in enumerate(features):
print(f"Stage {i}: {feat.shape}")
# Stage 0: torch.Size([2, 192, 56, 56])
# Stage 1: torch.Size([2, 384, 28, 28])
# Stage 2: torch.Size([2, 768, 14, 14])
# Stage 3: torch.Size([2, 1536, 7, 7])
# Fine-tuning completo con timm
from timm.loss import SoftTargetCrossEntropy
from timm.data.mixup import Mixup
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
# Parametri ottimali per fine-tuning Swin
model = timm.create_model('swin_base_patch4_window7_224',
pretrained=True, num_classes=10)
# Optimizer con parametri specifici per Swin
optimizer = create_optimizer_v2(
model,
opt='adamw',
lr=5e-5,
weight_decay=0.05,
layer_decay=0.9 # Layer-wise LR decay: layer più profondi = LR più bassa
)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Swin output: {out.shape}") # [2, 10]
Implementare optimizată: ONNX și TorchScript
Pentru implementarea în producție, este esențial să exportați modelul într-un format optimizat. ONNX permite interoperabilitatea între cadre și optimizări specifice hardware-ului; TorchScript elimină overheadul Python pentru inferență.
import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
import timm
# Modello ViT/Swin pre-addestrato e fine-tuned
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('best_vit.pth'))
model.eval()
# ============================================================
# EXPORT ONNX
# ============================================================
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"vit_model.onnx",
export_params=True,
opset_version=17, # ONNX opset 17 per operatori recenti
do_constant_folding=True, # Ottimizzazione grafo
input_names=['pixel_values'],
output_names=['logits'],
dynamic_axes={
'pixel_values': {0: 'batch_size'}, # Batch size dinamico
'logits': {0: 'batch_size'}
}
)
# Verifica modello ONNX
onnx_model = onnx.load("vit_model.onnx")
onnx.checker.check_model(onnx_model)
print("Modello ONNX valido!")
# Inferenza con ONNX Runtime (CPU o GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_session = ort.InferenceSession("vit_model.onnx", providers=providers)
# Test inferenza ONNX
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = ort_session.run(None, {'pixel_values': test_input})
print(f"ONNX output shape: {outputs[0].shape}")
# ============================================================
# TORCHSCRIPT (alternativa per deployment PyTorch)
# ============================================================
model_scripted = torch.jit.script(model)
model_scripted.save("vit_scripted.pt")
# Ricarica e usa
loaded = torch.jit.load("vit_scripted.pt")
with torch.no_grad():
out = loaded(dummy_input)
print(f"TorchScript output: {out.shape}")
# ============================================================
# BENCHMARK ONNX vs PyTorch
# ============================================================
import time
def benchmark(fn, n_runs=50, warmup=10):
for _ in range(warmup):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
t0 = time.perf_counter()
for _ in range(n_runs):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = (time.perf_counter() - t0) / n_runs * 1000
return elapsed
# PyTorch
def pt_inference():
with torch.no_grad():
model(dummy_input)
# ONNX Runtime
def onnx_inference():
ort_session.run(None, {'pixel_values': test_input})
pt_ms = benchmark(pt_inference)
onnx_ms = benchmark(onnx_inference)
print(f"PyTorch: {pt_ms:.1f} ms/inference")
print(f"ONNX RT: {onnx_ms:.1f} ms/inference")
print(f"Speedup ONNX: {pt_ms/onnx_ms:.2f}x")
ViT pentru sarcini specializate: medicale, prin satelit și multimodale
ViT-urile au demonstrat o capacitate de transfer excepțională pe domenii foarte diferite de ImageNet. Trei domenii de aplicare deosebit de importante sunt calculatoare vederii medicale (radiologie, patologie digitală, dermatologie), the teledetecție (imagini din satelit, imagini cu drone) și modele multimodal (CLIP, SigLIP, LLaVA).
import timm
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
# ============================================================
# ViT PER IMAGING MEDICO (classificazione CXR)
# ============================================================
# Chest X-Ray classification con DeiT fine-tuned
class MedicalViT(nn.Module):
"""
ViT per classificazione immagini mediche.
Usa un backbone pre-addestrato su ImageNet + fine-tuning su CXR.
Considera: le immagini mediche sono spesso grayscale (convertite a 3ch)
e richiedono risoluzione maggiore (384px).
"""
def __init__(self, n_classes: int, model_name: str = "deit3_base_patch16_384",
dropout: float = 0.2):
super().__init__()
# DeiT3 a 384px: più accurato per dettagli fini nelle immagini mediche
self.backbone = timm.create_model(
model_name,
pretrained=True,
num_classes=0, # Rimuovi head originale
img_size=384
)
d_model = self.backbone.embed_dim
# Head medica con dropout aggressivo (evita overfit su dataset piccoli)
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Dropout(dropout),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout / 2),
nn.Linear(d_model // 2, n_classes)
)
# Congela i primi 6 layer (feature basiche = ImageNet features)
# Fine-tuna solo i layer superiori (feature ad alto livello)
total_blocks = len(self.backbone.blocks)
freeze_until = total_blocks // 2
for i, block in enumerate(self.backbone.blocks):
if i < freeze_until:
for p in block.parameters():
p.requires_grad = False
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Parametri trainabili: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x) # [B, d_model] - CLS token
return self.head(features)
# Uso per NIH Chest X-Ray Dataset (14 classi, multi-label)
medical_vit = MedicalViT(n_classes=14, dropout=0.3)
x = torch.randn(4, 3, 384, 384) # 384px per CXR
out = medical_vit(x)
print(f"CXR prediction shape: {out.shape}") # [4, 14]
# ============================================================
# CLIP: VISION-LANGUAGE PRETRAINING
# ============================================================
# CLIP usa un ViT come encoder visuale accoppiato a un Transformer testuale.
# L'addestramento contrasto allinea rappresentazioni visive e testuali.
def clip_zero_shot_classification(
images: torch.Tensor,
class_descriptions: list, # ["a photo of a cat", "a photo of a dog", ...]
model_name: str = "openai/clip-vit-base-patch32"
):
"""
Zero-shot image classification con CLIP.
Non richiede esempi di training: usa descrizioni testuali delle classi.
"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
# Codifica testi e immagini nello stesso spazio embedding
with torch.no_grad():
# Text embeddings
text_inputs = processor(text=class_descriptions, return_tensors="pt",
padding=True, truncation=True)
text_features = model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Image embeddings (usa ViT internamente)
image_inputs = processor(images=images, return_tensors="pt")
image_features = model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Similarità coseno: matrice [n_images, n_classes]
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
return similarity
# Esempio: classificazione zero-shot senza training
class_names = [
"a chest X-ray showing pneumonia",
"a normal chest X-ray",
"a chest X-ray showing cardiomegaly",
"a chest X-ray with pleural effusion"
]
# similarity = clip_zero_shot_classification(images, class_names)
# print(f"Predicted class: {class_names[similarity.argmax()]}")
print("ViT multimodale (CLIP) pronto per zero-shot classification")
Optimizare ViT pentru dispozitive Edge
Implementarea hardware-ului ViT pe margine necesită strategii specifice. ViT-uri standard (86 M+ parametri) sunt prea grele pentru Raspberry Pi sau microcontrolere. Variantele mai ușoare precum ViT-Ti/16 (6M param) e MobileViT (5M param) sunt conceput pentru acest caz de utilizare, combinând puterea expresivă a atenției cu eficienţa circumvoluţiilor.
import timm
import torch
import torch.onnx
import time
import numpy as np
# ============================================================
# VARIANTI ViT LEGGERE PER EDGE
# ============================================================
edge_models = {
"vit_tiny_patch16_224": "ViT-Ti (6M, ~4ms GPU)",
"deit_tiny_patch16_224": "DeiT-Ti (5.7M, ~3ms GPU)",
"mobilevit_s": "MobileViT-S (5.6M, 4ms, ottimo CPU)",
"efficientvit_m0": "EfficientViT-M0 (2.4M, ultra-light)",
"fastvit_t8": "FastViT-T8 (4M, 3x più veloce di DeiT)",
}
def benchmark_edge_models(input_size=(1, 3, 224, 224), device="cpu", n_runs=50):
"""
Benchmark dei modelli ViT leggeri su CPU (simula edge device).
CPU benchmark e più rappresentativo di deployment su RPi/Jetson Nano.
"""
results = []
x = torch.randn(*input_size).to(device)
for model_name, description in edge_models.items():
try:
model = timm.create_model(model_name, pretrained=False, num_classes=10)
model = model.to(device).eval()
n_params = sum(p.numel() for p in model.parameters())
model_size_mb = n_params * 4 / (1024**2)
# Warmup
with torch.no_grad():
for _ in range(5):
model(x)
# Benchmark
t0 = time.perf_counter()
with torch.no_grad():
for _ in range(n_runs):
model(x)
latency_ms = (time.perf_counter() - t0) / n_runs * 1000
results.append({
"model": model_name,
"desc": description,
"params_M": n_params / 1e6,
"size_mb": model_size_mb,
"latency_ms": latency_ms
})
print(f"{model_name:<35} {n_params/1e6:>5.1f}M "
f"{model_size_mb:>6.1f}MB {latency_ms:>8.1f}ms")
except Exception as e:
print(f"{model_name}: Errore - {e}")
return results
# ============================================================
# EXPORT OTTIMIZZATO PER EDGE
# ============================================================
def export_vit_for_edge(model_name: str = "vit_tiny_patch16_224",
n_classes: int = 10):
"""
Pipeline completa: carica ViT-Ti, quantizza e esporta per edge.
"""
model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# 1. Export ONNX con opset 17
torch.onnx.export(
model, dummy_input, f"{model_name}_edge.onnx",
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}}
)
# 2. Quantizzazione dinamica INT8 (per CPU edge)
import torch.quantization
model_quantized = torch.quantization.quantize_dynamic(
model, {nn.Linear, nn.MultiheadAttention}, dtype=torch.qint8
)
# Confronto dimensioni
original_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
print(f"Modello originale FP32: {original_size:.1f} MB")
# Salva versione quantizzata
torch.save(model_quantized.state_dict(), f"{model_name}_int8.pt")
# Test latenza quantizzata su CPU
with torch.no_grad():
for _ in range(5): model_quantized(dummy_input) # warmup
t0 = time.perf_counter()
for _ in range(50): model_quantized(dummy_input)
lat_quant = (time.perf_counter() - t0) / 50 * 1000
print(f"Latenza INT8 CPU: {lat_quant:.1f}ms")
return model_quantized
print("ViT edge export pipeline pronto")
ViT vs CNN Benchmark on Common Tasks (2025)
| Model | ImageNet Top-1 | Latență (ms) | Debit (img/s) | Params |
|---|---|---|---|---|
| ResNet-50 | 76,1% | 4,1 ms | 1.240 | 25M |
| ConvNeXt-T | 82,1% | 5,5 ms | 960 | 29M |
| DeiT-B | 83,1% | 9,2 ms | 570 | 87M |
| Swin-T | 81,3% | 6,8 ms | 740 | 28M |
| ViT-B/16 (timm) | 85,5% | 11,4 ms | 460 | 86M |
| EfficientNet-B4 | 83,0% | 7,3 ms | 690 | 19M |
Măsurat pe RTX 4090, dimensiunea lotului 32, FP16. Latență = o singură imagine, lot=1.
Avertisment: ViT nu este întotdeauna cea mai bună alegere
- Seturi mici de date (imagini <10K): CNN sau EfficientNet au rezultate mai bune fără pregătire prealabilă la scară largă. ViT necesită o mulțime de date pentru a converge corect.
- Sarcini în timp real la margine: ViT-Ti/16 are o latență de ~4 ms pe GPU, dar >100 ms pe CPU. MobileNet sau EfficientNet-Lite sunt de preferat pentru implementările CPU.
- Detectarea obiectelor pe CPU: Swin și coloana vertebrală grozavă, dar grea. YOLO cu coloană ușoară învinge Swin pe CPU în latență.
- Reglare fină cu un set de date extrem de schimbare a domeniului: CNN-urile pre-instruite se pot generaliza mai bine dacă setul de date țintă este foarte diferit de ImageNet.
Cele mai bune practici pentru ViT în producție
Lista de verificare pentru implementarea ViT
- Alege varianta potrivita: ViT-Ti/S pentru resurse limitate, ViT-B pentru calitate standard, Swin-T/S pentru detectare/segmentare, DeiT-B pentru antrenament de la zero pe scara ImageNet.
- Pre-instruire pe ImageNet-21K: începe întotdeauna de la greutăți ImageNet-21K, nu ImageNet-1K. Oferă un salt semnificativ în precizie, în special cu seturi de date mici.
- Rată scăzută de învățare pentru reglare fină: utilizați baza LR 3e-5 pentru ViT-B, cu încălziri de minim 5 epoci. LR prea mare distruge reprezentările pre-antrenate.
- Rezoluție de intrare: ViT pre-antrenat la 224px funcționează cel mai bine cu intrare de 224px. Reglarea fină la 384px îmbunătățește acuratețea, dar costă de 2,3x memoria.
- Dimensiunea lotului și acumularea gradientului: ViT beneficiază de loturi mari (256-2048). Utilizați acumularea de gradient dacă VRAM nu este suficient.
-
Precizie mixtă (BF16/FP16): activați întotdeauna
torch.autocast. ViT obține o accelerare de două ori fără pierderi de precizie. -
Atentie Flash: STATELE UNITE ALE AMERICII
torch.nn.functional.scaled_dot_product_attention(PyTorch 2.0+) sauflash-attnpentru a reduce memoria atenției cu 40%.
Concluzii
Vision Transformers au redefinit peisajul viziunii computerizate. În 2026, dihotomia ViT vs CNN și în mare parte învechit: arhitecturile hibride (ConvNeXt, CoAtNet, FastViT) se combină cel mai bun din ambele lumi, în timp ce ViT-urile pure precum EVA și SigLIP domină benchmark-urile la scară largă.
Pentru practică, fluxul de lucru optim și clar: alegeți o coloană vertebrală pre-antrenată pe seturi mari de date (ImageNet-21K, LAION), reglaj fin cu augmentare agresivă (MixUp, CutMix, RandAugment) și încălzire LR, apoi exportați în ONNX pentru o implementare optimizată. Diferența față de CNN-uri nu este doar cantitativ – capacitatea globală de atenție permite ViT să capteze relații de lungă durată în imagini, cruciale pentru sarcini complexe de înțelegere vizuală.
Următorul pas în serie este Căutare arhitectură neuronală (NAS): cum automatizează alegerea arhitecturii optime pentru o anumită sarcină și buget de calcul, trecând dincolo de selecția manuală între ViT, CNN și variantele hibride.
Următorii pași
- Articolul următor: Căutare arhitectură neuronală: AutoML pentru învățare profundă
- Înrudit: Reglaj fin cu LoRA și QLoRA
- Seria Computer Vision: Detectarea obiectelor cu Swin Transformer
- Seria MLOps: Servirea modelelor de viziune în producție







