Vision Transformer (ViT): Architektura a praktické aplikace
V roce 2020 dokument Google Research radikálně změnil počítačové vidění: „Obrázek stojí za to 16x16 slov". Intuice byla jednoduchá, ale revoluční – použijte architekturu Transformer, dominantní v NLP, přímo na obrazy ošetřením i náplasti vizuální jako žetony. Výsledek to bylo Vision Transformer (ViT), která během pár let předstihla CNN nejmodernější na ImageNet a desítky dalších benchmarků, které dláždí cestu novým generace vizuálních modelů.
Příslib ViT není jen přesnost: je všestrannost. Stejná páteř Transformátor používaný pro text lze sdílet s obrázky, což umožňuje šablony multimódy jako CLIP, DALL-E a GPT-4V. ViTs se škálují lépe než CNN s daty a výpočtem doplňkové, a varianty jako např Swin Transformer e DeiT udělali tyto modely jsou účinné i na středně velkých souborech dat bez předběžného školení na stovkách z milionů obrázků.
V této příručce vytvoříme ViT od nuly v PyTorch a prozkoumáme architektonické varianty nejdůležitější a ukážeme, jak je doladit pro konkrétní výrobní úkoly.
Co se naučíte
- Architektura ViT: vkládání záplat, poziční kódování, vizuální pozornost
- Kompletní implementace od nuly pomocí PyTorch
- Rozdíly mezi ViT-B/16, ViT-L/32, DeiT, Swin Transformer
- Jemné ladění ViT předem trénovaného na vlastních datových sadách
- Techniky augmentace dat pro ViT (MixUp, CutMix, RandAugment)
- Rozbalení pozornosti a interpretovatelnost map pozornosti
- Optimalizované nasazení: ONNX, TorchScript, okrajová zařízení
- Srovnávací test ViT vs CNN na skutečných datových sadách
Architektura ViT: Jak to funguje
Vision Transformer bere obraz jako vstup a rozděluje jej na nepřekrývající se oblasti
pevná velikost (obvykle 16x16 nebo 32x32 pixelů). Každý patch přijde zploštělý
(zploštění) a promítnuté lineárně do kótovacího vektoru d_model (vložení).
Tyto vložky, tzv patch vložení, staňte se žetony Transformeru.
Speciální token [CLS] (třída token) má předponou k sekvenci, podobně
na BERT v NLP. Po dokončení kódování se předá reprezentace tokenu CLS
klasifikační hlavu pro vytvoření konečné predikce. Poziční kódování — ve tvaru
naučený místo sinus – je přidán do vložení, aby se zachovala informace
prostor, který by bez něj byl ztracen.
# Diagramma architettura ViT
#
# Input Image (224x224x3)
# |
# v
# Patch Extraction: divide in 196 patch di 16x16
# (224/16 = 14 patch per lato -> 14*14 = 196 patch)
# |
# v
# Patch Embedding: ogni patch [768] via Linear projection
# + [CLS] token -> sequenza di 197 token
# |
# v
# + Positional Embedding (learnable, 197x768)
# |
# v
# Transformer Encoder (L strati):
# - LayerNorm
# - Multi-Head Self-Attention (h heads)
# - Residual connection
# - LayerNorm
# - MLP (d_model -> 4*d_model -> d_model)
# - Residual connection
# |
# v
# [CLS] token representation
# |
# v
# MLP Head -> num_classes output
# Varianti standard:
# ViT-B/16: d_model=768, L=12, h=12 | ~86M param
# ViT-L/16: d_model=1024, L=24, h=16 | ~307M param
# ViT-H/14: d_model=1280, L=32, h=16 | ~632M param
Implementace ViT od nuly
Pojďme postavit kompletní ViT v PyTorch. Začněme od základní složky: Patch Embedding.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
# ============================================================
# 1. PATCH EMBEDDING
# ============================================================
class PatchEmbedding(nn.Module):
"""
Converte un'immagine in una sequenza di patch embedding.
Metodo 1: Convolution (efficiente, equivalente a patch+linear)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
# Equivalente a: flatten ogni patch + proiezione lineare
# Ma implementato come Conv2d per efficienza
self.projection = nn.Sequential(
# Divide in patch e proietta
nn.Conv2d(in_channels, d_model,
kernel_size=patch_size, stride=patch_size),
# [B, d_model, H/P, W/P] -> [B, n_patches, d_model]
Rearrange('b d h w -> b (h w) d')
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, C, H, W]
return self.projection(x) # [B, n_patches, d_model]
# ============================================================
# 2. MULTI-HEAD SELF ATTENTION per ViT
# ============================================================
class ViTAttention(nn.Module):
"""Multi-head self-attention con dropout."""
def __init__(self, d_model=768, n_heads=12, attn_dropout=0.0):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = self.head_dim ** -0.5
# QKV projection
self.qkv = nn.Linear(d_model, d_model * 3, bias=True)
self.attn_drop = nn.Dropout(attn_dropout)
self.proj = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor):
B, N, C = x.shape
# Calcola Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Ognuno: [B, heads, N, head_dim]
# Scaled dot-product attention
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
attn_weights = attn # Salva per attention rollout
attn = self.attn_drop(attn)
# Weighted sum + proiezione output
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, attn_weights
# ============================================================
# 3. TRANSFORMER ENCODER BLOCK
# ============================================================
class ViTBlock(nn.Module):
"""Singolo blocco Transformer per ViT."""
def __init__(self, d_model=768, n_heads=12, mlp_ratio=4.0,
dropout=0.0, attn_dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = ViTAttention(d_model, n_heads, attn_dropout)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model),
nn.Dropout(dropout)
)
def forward(self, x: torch.Tensor):
# Pre-norm + residual connection
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
x = x + self.mlp(self.norm2(x))
return x, attn_weights
# ============================================================
# 4. VISION TRANSFORMER COMPLETO
# ============================================================
class VisionTransformer(nn.Module):
"""
Vision Transformer (ViT) completo.
Paper: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
num_classes: int = 1000,
d_model: int = 768,
depth: int = 12,
n_heads: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
attn_dropout: float = 0.0,
representation_size: int = None # Pre-logit layer (opzionale)
):
super().__init__()
self.num_classes = num_classes
self.d_model = d_model
# Patch + Position Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
n_patches = self.patch_embed.n_patches
# Token CLS e positional embedding
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
ViTBlock(d_model, n_heads, mlp_ratio, dropout, attn_dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(d_model)
# Classification head
if representation_size is not None:
self.pre_logits = nn.Sequential(
nn.Linear(d_model, representation_size),
nn.Tanh()
)
else:
self.pre_logits = nn.Identity()
self.head = nn.Linear(
representation_size if representation_size else d_model,
num_classes
)
# Inizializzazione pesi
self._init_weights()
def _init_weights(self):
"""Inizializzazione standard per ViT."""
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor, return_attn: bool = False):
B = x.shape[0]
# 1. Patch embedding
x = self.patch_embed(x) # [B, n_patches, d_model]
# 2. Prepend CLS token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
x = torch.cat([cls_tokens, x], dim=1) # [B, n_patches+1, d_model]
# 3. Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# 4. Transformer blocks
attn_weights_list = []
for block in self.blocks:
x, attn_weights = block(x)
attn_weights_list.append(attn_weights)
# 5. Layer norm finale
x = self.norm(x)
# 6. Usa solo il CLS token per classificazione
cls_output = x[:, 0]
cls_output = self.pre_logits(cls_output)
logits = self.head(cls_output)
if return_attn:
return logits, attn_weights_list
return logits
# ============================================================
# 5. CREAZIONE VARIANTI STANDARD
# ============================================================
def vit_base_16(num_classes=1000, **kwargs):
"""ViT-B/16: 86M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=768,
depth=12, n_heads=12, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_large_16(num_classes=1000, **kwargs):
"""ViT-L/16: 307M parametri, input 224x224."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=1024,
depth=24, n_heads=16, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
def vit_tiny_16(num_classes=1000, **kwargs):
"""ViT-Ti/16: ~6M parametri, per edge/mobile."""
return VisionTransformer(
img_size=224, patch_size=16, d_model=192,
depth=12, n_heads=3, mlp_ratio=4.0,
num_classes=num_classes, **kwargs
)
# Test
model = vit_base_16(num_classes=100)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Input: {x.shape}")
print(f"Output: {out.shape}") # [2, 100]
print(f"Parametri: {sum(p.numel() for p in model.parameters()):,}")
# Parametri: 85,880,164
Architektonické varianty: DeiT, Swin a BEiT
Původní ViT vyžadoval obrovské množství dat (JFT-300M, 300 milionů obrázků). předčí CNN. Toto omezení vedlo k vývoji datově efektivnějších variant:
| Model | Rok | Klíčová inovace | ImageNet Top-1 | Parametry |
|---|---|---|---|---|
| ViT-B/16 | 2020 | První ViT, vyžaduje JFT-300M | 81,8 % | 86 mil |
| DeiT-B | 2021 | Destilace od učitele CNN, pouze ImageNet | 83,1 % | 87 mil |
| Swin-B | 2021 | Posunuté okno Pozor, hierarchické | 85,2 % | 88 milionů |
| BEiT-L | 2022 | Maskované obrazové modelování (BERT pro vidění) | 87,4 % | 307 mil |
| DeiT III-H | 2022 | Pokročilý tréninkový recept | 87,7 % | 632 mil |
| ViT-G (EVA) | 2023 | Měřítka na 1B param, CLIP předtrénink | 89,6 % | 1,0B |
DeiT (datově efektivní obrazové transformátory) Facebook AI a pravděpodobně varianta nejdůležitější pro praxi: představuje destilační žeton která vám umožní učit se od učitele CNN (jako je RegNet nebo ConvNext), čímž získáte vynikající výkon pouze s ImageNet-1K.
Swin Transformer řeší problém kvadratické složitosti pozornosti představení ShiftedWindows: pozornost se počítá v rámci místních oken spíše než přes celý obraz, s lineárními výpočetními náklady vzhledem k obrazu. Swine vytváří hierarchické reprezentace (jako CNN) a je preferovanou páteří pro detekci a segmentace.
Jemné doladění ViT Předtrénováno
Nejpraktičtější způsob, jak používat ViTs ve výrobě, je začít od předem vyškoleného modelu ImageNet-21K a dolaďte svou datovou sadu. Hugging Face Transformers nabízí vše Core ViT modely s jednotným API.
# pip install transformers timm torch torchvision datasets
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
ViTForImageClassification, ViTImageProcessor,
AutoImageProcessor
)
from torchvision import datasets, transforms
import os
# ============================================================
# FINE-TUNING ViT-B/16 su Dataset Custom
# ============================================================
class ViTFineTuner(nn.Module):
"""
ViT pre-addestrato con classification head custom.
Supporta fine-tuning parziale o completo.
"""
def __init__(self, num_classes: int, model_name: str = "google/vit-base-patch16-224",
freeze_backbone: bool = False):
super().__init__()
# Carica ViT pre-addestrato da HuggingFace
self.vit = ViTForImageClassification.from_pretrained(
model_name,
num_labels=num_classes,
ignore_mismatched_sizes=True # Permette cambio num_classes
)
if freeze_backbone:
# Congela tutto tranne il classification head
for param in self.vit.vit.parameters():
param.requires_grad = False
# Solo il classifier rimane trainable
print(f"Parametri trainabili: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
def forward(self, x):
outputs = self.vit(pixel_values=x)
return outputs.logits
# ============================================================
# DATA AUGMENTATION per ViT
# ============================================================
def get_vit_transforms(img_size: int = 224, mode: str = "train"):
"""
Augmentation pipeline ottimizzata per ViT.
ViT beneficia molto da augmentation aggressiva.
"""
if mode == "train":
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(num_ops=2, magnitude=9), # RandAugment
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.25) # CutOut/Erasing
])
else:
# Resize + center crop per validation/test
return transforms.Compose([
transforms.Resize(int(img_size * 1.143)), # 256 per 224
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# ============================================================
# TRAINING LOOP CON WARMUP + COSINE DECAY
# ============================================================
import math
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
"""LR schedule: linear warmup + cosine decay (standard per ViT)."""
def lr_lambda(step):
if step < warmup_steps:
return float(step) / float(max(1, warmup_steps))
progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return LambdaLR(optimizer, lr_lambda)
def train_vit(
model, train_loader, val_loader,
num_epochs=30, base_lr=3e-5, weight_decay=0.05,
device="cuda", label_smoothing=0.1
):
model = model.to(device)
# AdamW con weight decay (standard per ViT)
# Escludi bias e LayerNorm dal weight decay
no_decay_params = []
decay_params = []
for name, param in model.named_parameters():
if param.requires_grad:
if 'bias' in name or 'norm' in name or 'cls_token' in name or 'pos_embed' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = torch.optim.AdamW([
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=base_lr)
total_steps = len(train_loader) * num_epochs
warmup_steps = len(train_loader) * 5 # 5 epoch di warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
# Label smoothing loss
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
best_acc = 0.0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs, labels = imgs.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(imgs)
loss = criterion(logits, labels)
loss.backward()
# Gradient clipping (importante per ViT)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
val_acc = correct / total
avg_loss = train_loss / len(train_loader)
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{num_epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Val Acc: {val_acc:.4f} | "
f"LR: {current_lr:.2e}")
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), "best_vit.pth")
print(f" -> Nuovo best: {best_acc:.4f}")
return best_acc
MixUp a CutMix: Pokročilá augmentace pro ViT
ViTs těží zejména z technik mixup augmentace. MixUp lineárně kombinuje dvojice obrázků a jejich popisky; CutMix nahradí porci obdélníková část jednoho obrázku s odpovídající částí jiného. Obě techniky zlepšit zobecnění a kalibraci modelu.
import numpy as np
class MixUpCutMix:
"""
Combinazione di MixUp e CutMix come in DeiT e timm.
Applica randomicamente uno dei due metodi per ogni batch.
"""
def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0,
prob=0.5, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.prob = prob
self.num_classes = num_classes
def one_hot(self, labels: torch.Tensor) -> torch.Tensor:
return F.one_hot(labels, self.num_classes).float()
def mixup(self, imgs, labels_oh):
"""MixUp: interpolazione lineare."""
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
B = imgs.size(0)
idx = torch.randperm(B)
mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
mixed_labels = lam * labels_oh + (1 - lam) * labels_oh[idx]
return mixed_imgs, mixed_labels
def cutmix(self, imgs, labels_oh):
"""CutMix: ritaglia e incolla patch."""
lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
B, C, H, W = imgs.shape
idx = torch.randperm(B)
# Calcola dimensioni bounding box
cut_ratio = math.sqrt(1.0 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
# Centro casuale
cy = np.random.randint(H)
cx = np.random.randint(W)
y1 = max(0, cy - cut_h // 2)
y2 = min(H, cy + cut_h // 2)
x1 = max(0, cx - cut_w // 2)
x2 = min(W, cx + cut_w // 2)
# Applica CutMix
mixed_imgs = imgs.clone()
mixed_imgs[:, :, y1:y2, x1:x2] = imgs[idx, :, y1:y2, x1:x2]
# Ricalcola lambda effettivo
lam_actual = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
mixed_labels = lam_actual * labels_oh + (1 - lam_actual) * labels_oh[idx]
return mixed_imgs, mixed_labels
def __call__(self, imgs, labels):
labels_oh = self.one_hot(labels).to(imgs.device)
if np.random.random() < self.prob:
if np.random.random() < 0.5:
return self.mixup(imgs, labels_oh)
else:
return self.cutmix(imgs, labels_oh)
return imgs, labels_oh
# Uso nel training loop
mixup_cutmix = MixUpCutMix(num_classes=100)
# Nel training loop:
# imgs, soft_labels = mixup_cutmix(imgs, labels)
# loss = F.cross_entropy(logits, soft_labels) # Soft labels
Zavedení pozornosti: Vizualizujte, co vidí ViT
Jednou z nejzajímavějších funkcí ViTs je schopnost vizualizace mapy pozornosti abyste pochopili, které oblasti obrazu model zvažuje relevantní. Technika Pozornost Rollout propaguje pozornost napříč všemi vrstvami získat mapu globálního významu.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def compute_attention_rollout(attn_weights_list: list,
discard_ratio: float = 0.9) -> np.ndarray:
"""
Attention Rollout (Abnar & Zuidema, 2020).
Propaga le attention maps attraverso tutti i layer.
attn_weights_list: lista di tensori [B, heads, N, N]
discard_ratio: percentuale di attention da azzerare (focus sui top)
"""
n_layers = len(attn_weights_list)
# Media su tutte le teste
rollout = None
for attn in attn_weights_list:
# attn: [B, heads, N, N] -> media teste -> [B, N, N]
attn_mean = attn.mean(dim=1) # [B, N, N]
# Aggiunge identità (residual connection)
eye = torch.eye(attn_mean.size(-1), device=attn_mean.device)
attn_mean = attn_mean + eye
attn_mean = attn_mean / attn_mean.sum(dim=-1, keepdim=True)
if rollout is None:
rollout = attn_mean
else:
rollout = torch.bmm(attn_mean, rollout)
return rollout
def visualize_vit_attention(model, image_tensor: torch.Tensor,
patch_size: int = 16):
"""
Visualizza l'attention del CLS token sull'immagine.
"""
model.eval()
with torch.no_grad():
_, attn_list = model(image_tensor.unsqueeze(0), return_attn=True)
# Calcola rollout
rollout = compute_attention_rollout(attn_list) # [1, N+1, N+1]
# Attenzione del CLS verso tutti i patch
cls_attn = rollout[0, 0, 1:] # Escludi il CLS token stesso
# Ridimensiona alla griglia dei patch
H = W = int(math.sqrt(cls_attn.size(0)))
attn_map = cls_attn.reshape(H, W).cpu().numpy()
# Normalizza e upscale alla dimensione immagine
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
attn_map_upscaled = np.kron(attn_map, np.ones((patch_size, patch_size)))
return attn_map_upscaled
# Esempio di utilizzo e visualizzazione
# model = vit_base_16(num_classes=1000)
# img_tensor = get_vit_transforms(mode="val")(Image.open("dog.jpg"))
# attn_map = visualize_vit_attention(model, img_tensor)
#
# plt.figure(figsize=(12, 4))
# plt.subplot(1, 2, 1)
# plt.imshow(img_tensor.permute(1,2,0).numpy())
# plt.title("Immagine originale")
# plt.subplot(1, 2, 2)
# plt.imshow(attn_map, cmap='inferno')
# plt.title("Attention Rollout (CLS token)")
# plt.colorbar()
# plt.tight_layout()
# plt.savefig("vit_attention.png", dpi=150)
Swin Transformer: Pozor na hierarchická okna
Il Swin Transformer řeší dvě základní omezení standardního ViT: kvadratická složitost pozornosti (která omezuje zpracovatelné rozlišení) a absence hierarchických reprezentací (nezbytných pro detekci a segmentaci).
Swin rozdělí obrázek na nepřekrývající se okna a počítá pozornost pouze uvnitř každého okna (lineární složitost). Mezi jednou vrstvou a druhou přicházejí okna posun umožňující komunikaci mezi sousedními okny. Hierarchická struktura progresivně snižuje prostorové rozlišení a vytváří mapy rysů ve 4 měřítku, jako jsou např tradiční CNN.
# Uso di Swin Transformer tramite timm (più semplice che implementare da zero)
# pip install timm
import timm
import torch
# Crea Swin-T (Tiny): 28M param, 81.3% ImageNet Top-1
swin_tiny = timm.create_model(
'swin_tiny_patch4_window7_224',
pretrained=True,
num_classes=0 # 0 = rimuovi classifier (backbone solo)
)
# Swin-B (Base): 88M param, 85.2% ImageNet Top-1
swin_base = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
num_classes=100 # Custom classifier
)
# Swin-V2-L per alta risoluzione (resolution scaling)
swin_v2 = timm.create_model(
'swinv2_large_window12to16_192to256_22kft1k',
pretrained=True,
num_classes=10
)
# Verifica feature maps gerarchiche (per detection/segmentation)
swin_backbone = timm.create_model(
'swin_base_patch4_window7_224',
pretrained=True,
features_only=True, # Restituisce feature a 4 scale
out_indices=(0, 1, 2, 3)
)
x = torch.randn(2, 3, 224, 224)
features = swin_backbone(x)
for i, feat in enumerate(features):
print(f"Stage {i}: {feat.shape}")
# Stage 0: torch.Size([2, 192, 56, 56])
# Stage 1: torch.Size([2, 384, 28, 28])
# Stage 2: torch.Size([2, 768, 14, 14])
# Stage 3: torch.Size([2, 1536, 7, 7])
# Fine-tuning completo con timm
from timm.loss import SoftTargetCrossEntropy
from timm.data.mixup import Mixup
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
# Parametri ottimali per fine-tuning Swin
model = timm.create_model('swin_base_patch4_window7_224',
pretrained=True, num_classes=10)
# Optimizer con parametri specifici per Swin
optimizer = create_optimizer_v2(
model,
opt='adamw',
lr=5e-5,
weight_decay=0.05,
layer_decay=0.9 # Layer-wise LR decay: layer più profondi = LR più bassa
)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(f"Swin output: {out.shape}") # [2, 10]
Optimalizované nasazení: ONNX a TorchScript
Pro produkční nasazení je nezbytné exportovat model v optimalizovaném formátu. ONNX umožňuje interoperabilitu mezi frameworky a hardwarově specifické optimalizace; TorchScript eliminuje režii Pythonu pro odvození.
import torch
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
import timm
# Modello ViT/Swin pre-addestrato e fine-tuned
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)
model.load_state_dict(torch.load('best_vit.pth'))
model.eval()
# ============================================================
# EXPORT ONNX
# ============================================================
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"vit_model.onnx",
export_params=True,
opset_version=17, # ONNX opset 17 per operatori recenti
do_constant_folding=True, # Ottimizzazione grafo
input_names=['pixel_values'],
output_names=['logits'],
dynamic_axes={
'pixel_values': {0: 'batch_size'}, # Batch size dinamico
'logits': {0: 'batch_size'}
}
)
# Verifica modello ONNX
onnx_model = onnx.load("vit_model.onnx")
onnx.checker.check_model(onnx_model)
print("Modello ONNX valido!")
# Inferenza con ONNX Runtime (CPU o GPU)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
ort_session = ort.InferenceSession("vit_model.onnx", providers=providers)
# Test inferenza ONNX
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = ort_session.run(None, {'pixel_values': test_input})
print(f"ONNX output shape: {outputs[0].shape}")
# ============================================================
# TORCHSCRIPT (alternativa per deployment PyTorch)
# ============================================================
model_scripted = torch.jit.script(model)
model_scripted.save("vit_scripted.pt")
# Ricarica e usa
loaded = torch.jit.load("vit_scripted.pt")
with torch.no_grad():
out = loaded(dummy_input)
print(f"TorchScript output: {out.shape}")
# ============================================================
# BENCHMARK ONNX vs PyTorch
# ============================================================
import time
def benchmark(fn, n_runs=50, warmup=10):
for _ in range(warmup):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
t0 = time.perf_counter()
for _ in range(n_runs):
fn()
torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = (time.perf_counter() - t0) / n_runs * 1000
return elapsed
# PyTorch
def pt_inference():
with torch.no_grad():
model(dummy_input)
# ONNX Runtime
def onnx_inference():
ort_session.run(None, {'pixel_values': test_input})
pt_ms = benchmark(pt_inference)
onnx_ms = benchmark(onnx_inference)
print(f"PyTorch: {pt_ms:.1f} ms/inference")
print(f"ONNX RT: {onnx_ms:.1f} ms/inference")
print(f"Speedup ONNX: {pt_ms/onnx_ms:.2f}x")
ViT pro specializované úkoly: lékařské, satelitní a multimodální
ViTs prokázaly výjimečnou přenosovou kapacitu na doménách velmi odlišných od ImageNet. Jsou to tři zvláště důležité oblasti použití počítače lékařské vidění (radiologie, digitální patologie, dermatologie), dálkového průzkumu Země (satelitní snímky, snímky z dronů) a modely multimodální (CLIP, SigLIP, LLaVA).
import timm
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
# ============================================================
# ViT PER IMAGING MEDICO (classificazione CXR)
# ============================================================
# Chest X-Ray classification con DeiT fine-tuned
class MedicalViT(nn.Module):
"""
ViT per classificazione immagini mediche.
Usa un backbone pre-addestrato su ImageNet + fine-tuning su CXR.
Considera: le immagini mediche sono spesso grayscale (convertite a 3ch)
e richiedono risoluzione maggiore (384px).
"""
def __init__(self, n_classes: int, model_name: str = "deit3_base_patch16_384",
dropout: float = 0.2):
super().__init__()
# DeiT3 a 384px: più accurato per dettagli fini nelle immagini mediche
self.backbone = timm.create_model(
model_name,
pretrained=True,
num_classes=0, # Rimuovi head originale
img_size=384
)
d_model = self.backbone.embed_dim
# Head medica con dropout aggressivo (evita overfit su dataset piccoli)
self.head = nn.Sequential(
nn.LayerNorm(d_model),
nn.Dropout(dropout),
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Dropout(dropout / 2),
nn.Linear(d_model // 2, n_classes)
)
# Congela i primi 6 layer (feature basiche = ImageNet features)
# Fine-tuna solo i layer superiori (feature ad alto livello)
total_blocks = len(self.backbone.blocks)
freeze_until = total_blocks // 2
for i, block in enumerate(self.backbone.blocks):
if i < freeze_until:
for p in block.parameters():
p.requires_grad = False
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"Parametri trainabili: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x) # [B, d_model] - CLS token
return self.head(features)
# Uso per NIH Chest X-Ray Dataset (14 classi, multi-label)
medical_vit = MedicalViT(n_classes=14, dropout=0.3)
x = torch.randn(4, 3, 384, 384) # 384px per CXR
out = medical_vit(x)
print(f"CXR prediction shape: {out.shape}") # [4, 14]
# ============================================================
# CLIP: VISION-LANGUAGE PRETRAINING
# ============================================================
# CLIP usa un ViT come encoder visuale accoppiato a un Transformer testuale.
# L'addestramento contrasto allinea rappresentazioni visive e testuali.
def clip_zero_shot_classification(
images: torch.Tensor,
class_descriptions: list, # ["a photo of a cat", "a photo of a dog", ...]
model_name: str = "openai/clip-vit-base-patch32"
):
"""
Zero-shot image classification con CLIP.
Non richiede esempi di training: usa descrizioni testuali delle classi.
"""
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
model.eval()
# Codifica testi e immagini nello stesso spazio embedding
with torch.no_grad():
# Text embeddings
text_inputs = processor(text=class_descriptions, return_tensors="pt",
padding=True, truncation=True)
text_features = model.get_text_features(**text_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Image embeddings (usa ViT internamente)
image_inputs = processor(images=images, return_tensors="pt")
image_features = model.get_image_features(**image_inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Similarità coseno: matrice [n_images, n_classes]
similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
return similarity
# Esempio: classificazione zero-shot senza training
class_names = [
"a chest X-ray showing pneumonia",
"a normal chest X-ray",
"a chest X-ray showing cardiomegaly",
"a chest X-ray with pleural effusion"
]
# similarity = clip_zero_shot_classification(images, class_names)
# print(f"Predicted class: {class_names[similarity.argmax()]}")
print("ViT multimodale (CLIP) pronto per zero-shot classification")
Optimalizace ViT pro zařízení Edge
Nasazení ViT na okrajový hardware vyžaduje specifické strategie. Standardní ViTs (86M+ parametry) jsou příliš těžké pro Raspberry Pi nebo mikrokontroléry. Lehčí varianty jako např ViT-Ti/16 (6M parametr) e MobileViT (param 5M). navrženo pro tento případ použití a kombinuje výrazovou sílu pozornosti s účinnost konvolucí.
import timm
import torch
import torch.onnx
import time
import numpy as np
# ============================================================
# VARIANTI ViT LEGGERE PER EDGE
# ============================================================
edge_models = {
"vit_tiny_patch16_224": "ViT-Ti (6M, ~4ms GPU)",
"deit_tiny_patch16_224": "DeiT-Ti (5.7M, ~3ms GPU)",
"mobilevit_s": "MobileViT-S (5.6M, 4ms, ottimo CPU)",
"efficientvit_m0": "EfficientViT-M0 (2.4M, ultra-light)",
"fastvit_t8": "FastViT-T8 (4M, 3x più veloce di DeiT)",
}
def benchmark_edge_models(input_size=(1, 3, 224, 224), device="cpu", n_runs=50):
"""
Benchmark dei modelli ViT leggeri su CPU (simula edge device).
CPU benchmark e più rappresentativo di deployment su RPi/Jetson Nano.
"""
results = []
x = torch.randn(*input_size).to(device)
for model_name, description in edge_models.items():
try:
model = timm.create_model(model_name, pretrained=False, num_classes=10)
model = model.to(device).eval()
n_params = sum(p.numel() for p in model.parameters())
model_size_mb = n_params * 4 / (1024**2)
# Warmup
with torch.no_grad():
for _ in range(5):
model(x)
# Benchmark
t0 = time.perf_counter()
with torch.no_grad():
for _ in range(n_runs):
model(x)
latency_ms = (time.perf_counter() - t0) / n_runs * 1000
results.append({
"model": model_name,
"desc": description,
"params_M": n_params / 1e6,
"size_mb": model_size_mb,
"latency_ms": latency_ms
})
print(f"{model_name:<35} {n_params/1e6:>5.1f}M "
f"{model_size_mb:>6.1f}MB {latency_ms:>8.1f}ms")
except Exception as e:
print(f"{model_name}: Errore - {e}")
return results
# ============================================================
# EXPORT OTTIMIZZATO PER EDGE
# ============================================================
def export_vit_for_edge(model_name: str = "vit_tiny_patch16_224",
n_classes: int = 10):
"""
Pipeline completa: carica ViT-Ti, quantizza e esporta per edge.
"""
model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# 1. Export ONNX con opset 17
torch.onnx.export(
model, dummy_input, f"{model_name}_edge.onnx",
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}}
)
# 2. Quantizzazione dinamica INT8 (per CPU edge)
import torch.quantization
model_quantized = torch.quantization.quantize_dynamic(
model, {nn.Linear, nn.MultiheadAttention}, dtype=torch.qint8
)
# Confronto dimensioni
original_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
print(f"Modello originale FP32: {original_size:.1f} MB")
# Salva versione quantizzata
torch.save(model_quantized.state_dict(), f"{model_name}_int8.pt")
# Test latenza quantizzata su CPU
with torch.no_grad():
for _ in range(5): model_quantized(dummy_input) # warmup
t0 = time.perf_counter()
for _ in range(50): model_quantized(dummy_input)
lat_quant = (time.perf_counter() - t0) / 50 * 1000
print(f"Latenza INT8 CPU: {lat_quant:.1f}ms")
return model_quantized
print("ViT edge export pipeline pronto")
ViT vs CNN Benchmark on Common Tasks (2025)
| Model | ImageNet Top-1 | Latence (ms) | Propustnost (img/s) | Parametry |
|---|---|---|---|---|
| ResNet-50 | 76,1 % | 4,1 ms | 1,240 | 25 mil |
| ConvNeXt-T | 82,1 % | 5,5 ms | 960 | 29 mil |
| DeiT-B | 83,1 % | 9,2 ms | 570 | 87 mil |
| Swin-T | 81,3 % | 6,8 ms | 740 | 28 mil |
| ViT-B/16 (timm) | 85,5 % | 11,4 ms | 460 | 86 mil |
| EfficientNet-B4 | 83,0 % | 7,3 ms | 690 | 19 mil |
Měřeno na RTX 4090, velikost šarže 32, FP16. Latence = jeden obrázek, dávka = 1.
Upozornění: ViT není vždy tou nejlepší volbou
- Malé datové sady (<10 000 obrázků): CNN nebo EfficientNet fungují lépe bez rozsáhlého předběžného školení. ViT vyžaduje ke správné konvergaci hodně dat.
- Úkoly v reálném čase na okraji: ViT-Ti/16 má ~4ms latenci na GPU, ale >100ms na CPU. Pro nasazení CPU jsou vhodnější MobileNet nebo EfficientNet-Lite.
- Detekce objektů na CPU: Swin a skvělá páteř, ale těžká. YOLO s lehkou páteří poráží Swin na CPU v latenci.
- Jemné ladění s extrémním posunem domény: Předtrénované CNN mohou lépe zobecnit, pokud se cílová datová sada velmi liší od ImageNet.
Nejlepší postupy pro ViT ve výrobě
Kontrolní seznam pro Deployment ViT
- Vyberte si správnou variantu: ViT-Ti/S pro omezené zdroje, ViT-B pro standardní kvalitu, Swin-T/S pro detekci/segmentaci, DeiT-B pro trénink od nuly na ImageNet-scale.
- Předškolní příprava na ImageNet-21K: vždy začíná od vah ImageNet-21K, nikoli ImageNet-1K. Nabízí výrazný skok v přesnosti, zejména u malých datových sad.
- Nízká rychlost učení pro jemné doladění: použijte základnu LR 3e-5 pro ViT-B, se zahříváním alespoň 5 epoch. Příliš vysoké LR ničí předem trénované reprezentace.
- Vstupní rozlišení: Nejlépe funguje ViT předem natrénovaný na 224px se vstupem 224px. Jemné doladění na 384 pixelů zvyšuje přesnost, ale stojí 2,3x paměť.
- Velikost dávky a akumulace gradientu: ViT těží z velkých velikostí dávek (256-2048). Pokud VRAM nestačí, použijte akumulaci gradientu.
-
Smíšená přesnost (BF16/FP16): vždy povolit
torch.autocast. ViT získává 2x zrychlení bez ztráty přesnosti. -
Pozor na blesk: USA
torch.nn.functional.scaled_dot_product_attention(PyTorch 2.0+) neboflash-attnsnížit pozornost paměti o 40 %.
Závěry
Vision Transformers nově definovali prostředí počítačového vidění. V roce 2026 dichotomie ViT vs CNN a do značné míry zastaralé: hybridní architektury (ConvNeXt, CoAtNet, FastViT) kombinují to nejlepší z obou světů, zatímco čisté ViTs jako EVA a SigLIP dominují rozsáhlým benchmarkům.
Pro praxi optimální a jasný pracovní postup: vyberte si páteř předem vycvičenou na velké datové sady (ImageNet-21K, LAION), dolaďte pomocí agresivní augmentace (MixUp, CutMix, RandAugment) a LR warmup, poté exportujte do ONNX pro optimalizované nasazení. Rozdíl oproti CNN není to jen kvantitativní – schopnost globální pozornosti umožňuje ViT zaujmout vztahy na dlouhé vzdálenosti v obrazech, klíčové pro komplexní úkoly vizuálního porozumění.
Dalším krokem v řadě je Neural Architecture Search (NAS): jak automatizovat výběr optimální architektury pro danou úlohu a výpočetní rozpočet, přesahující manuální výběr mezi ViT, CNN a hybridními variantami.
Další kroky
- Další článek: Hledání neuronové architektury: AutoML pro hluboké učení
- Související: Jemné doladění pomocí LoRA a QLoRA
- Řada Computer Vision: Detekce objektů pomocí Swin Transformer
- Řada MLOps: Servírování modelů Vision ve výrobě







