Segmentarea semantică și a instanțelor: U-Net, Mask R-CNN și SAM
Segmentarea imaginii reprezintă cel mai granular nivel de înțelegere vizuală: în loc să știi „există o tumoare în această imagine” (clasificare) sau „se găsește tumora în această zonă” (detecție), vrem să știm exact care pixeli aparțin tumorii. Această precizie perfectă a pixelilor este fundamentală în medicină, chirurgie robotică, conducere autonomă și controlul calității industrial.
În acest articol vom explora cele mai importante arhitecturi pentru segmentare: U-Net (modelul care a revoluționat segmentarea medicală), Mască R-CNN (standardul de aur, de exemplu, segmentarea) e SAM (Segmentează Anything Model de Meta AI, care a redefinit limitele posibilului).
Ce vei învăța
- Arhitectura U-Net: codificator-decodor cu conexiuni de salt pentru segmentarea medicala
- Implementarea U-Net de la zero în PyTorch cu instruire privind seturile de date medicale
- Masca R-CNN: segmentare a instanțelor cu casete de delimitare + măști binare
- Segment Anything Model (SAM): segmentare zero-shot cu indicații vizuale
- Valori de evaluare: Dice Score, IoU, Precision/Recall pentru segmentare
- Tehnici de post-procesare: CRF, morfologie matematică
- Studiu de caz: segmentarea plămânilor din radiografii (set de date open source)
- Implementarea modelelor de segmentare in productie
1. Fundamentele Segmentării
1.1 Tipuri de segmentare
Taxonomia de segmentare
| Tip | Distinge Instanțele | Contextul clasamentului | Ieșiri | Arhitecturi |
|---|---|---|---|---|
| Semantică | No | Si | Hartă HxW cu etichete pe pixel | U-Net, DeepLabv3, SegFormer |
| De exemplu | Si | Nu (doar „lucruri”) | Mască binară pentru obiect | Mască R-CNN, SOLOv2, YOLACT |
| Panoptic | Da (pentru „lucruri”) | Da (pentru „lucruri”) | Instanță unificată+hartă semantică | Panoptic FPN, Mask2Former |
| Interactiv | Da (cu prompt) | Depinde de prompt | Masca condusă de clic/bbox | SAM, SAM2, ClickSEG |
1.2 Măsuri de evaluare
Pentru segmentare, sunt utilizate valori specifice care măsoară suprapunerea pixel cu pixel între masca prezisă și adevărul de bază:
import torch
import numpy as np
from typing import Union
def compute_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> float:
"""
Intersection over Union per segmentazione binaria.
pred, target: tensori [H, W] o [B, H, W] con valori in [0,1]
"""
pred_binary = (pred >= threshold).bool()
target_binary = target.bool()
intersection = (pred_binary & target_binary).float().sum()
union = (pred_binary | target_binary).float().sum()
if union == 0:
return 1.0 # caso degenere: entrambe vuote
return float(intersection / union)
def dice_score(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
smooth: float = 1.0) -> float:
"""
Dice Score (F1 per segmentazione): 2*|X intersect Y| / (|X| + |Y|)
Preferito in ambito medico perchè meno sensibile agli sbilanciamenti.
Valore: 0 (peggio) -> 1 (perfetto)
"""
pred_binary = (pred >= threshold).float()
target_binary = target.float()
intersection = (pred_binary * target_binary).sum()
dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target_binary.sum() + smooth)
return float(dice)
def compute_multiclass_miou(pred_logits: torch.Tensor, targets: torch.Tensor,
num_classes: int, ignore_index: int = 255) -> float:
"""
mIoU per segmentazione semantica multi-classe.
pred_logits: [B, C, H, W] - logit grezzi
targets: [B, H, W] - indici di classe 0..num_classes-1
"""
preds = pred_logits.argmax(dim=1) # [B, H, W]
ious = []
for cls in range(num_classes):
pred_cls = preds == cls
true_cls = targets == cls
valid = targets != ignore_index
pred_cls = pred_cls & valid
true_cls = true_cls & valid
intersection = (pred_cls & true_cls).sum().float()
union = (pred_cls | true_cls).sum().float()
if union > 0:
ious.append(float(intersection / union))
return float(np.mean(ious)) if ious else 0.0
def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
"""
Hausdorff Distance: misura la distanza massima tra i bordi delle maschere.
Utile in medicina per valutare la precisione dei contorni.
"""
from scipy.spatial.distance import directed_hausdorff
pred_points = np.argwhere(pred)
target_points = np.argwhere(target)
if len(pred_points) == 0 or len(target_points) == 0:
return float('inf')
d1 = directed_hausdorff(pred_points, target_points)[0]
d2 = directed_hausdorff(target_points, pred_points)[0]
return max(d1, d2)
print("Esempio metriche:")
pred = torch.sigmoid(torch.randn(256, 256))
target = (torch.randn(256, 256) > 0).float()
iou = compute_iou(pred, target)
dice = dice_score(pred, target)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
2. U-Net: Rețeaua de Segmentare Medicală
U-Net (Ronneberger et al., 2015) a fost propus inițial pentru segmentare a imaginilor biomedicale. Arhitectura sa în formă de „U” cu sări peste conexiuni între codificator și decodor și a devenit șablonul dominant pentru orice sarcină de segmentare dens, de la pixeli medicali la hărți prin satelit, de la imagini industriale la scene în aer liber.
2.1 Arhitectura U-Net
Arhitectura este împărțită în trei părți:
- Encoder (cale de contracție): serie de blocuri convoluționale + pooling maxim care reduc rezoluția și măresc canalele, extragând caracteristici bogate din punct de vedere semantic, dar imprecise din punct de vedere spațial
- Gâtul de sticlă: cel mai adânc bloc, funcționează la cea mai mică rezoluție
- Decodor (cale de expansiune): serie de supraeșantionare + conversie care restabilește rezoluția originală, concatenând hărțile caracteristicilor codificatorului prin ignorarea conexiunilor pentru a recupera detaliile spațiale
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Blocco base U-Net: Conv-BN-ReLU-Conv-BN-ReLU."""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.double_conv(x)
class DownBlock(nn.Module):
"""Encoder block: MaxPool2d + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Decoder block: Upsample + concatenazione skip + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# Padding se le dimensioni non coincidono esattamente
diff_h = x2.size(2) - x1.size(2)
diff_w = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
diff_h // 2, diff_h - diff_h // 2])
# Skip connection: concatena feature encoder + decoder
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net originale per segmentazione binaria o multi-classe.
Architettura:
Input -> [64] -> [128] -> [256] -> [512] -> [1024] (bottleneck)
-> [512] -> [256] -> [128] -> [64] -> Output
Le frecce verso il basso sono encoder (+ maxpool)
Le frecce verso l'alto sono decoder (+ skip connections)
"""
def __init__(self, in_channels: int = 1, num_classes: int = 1,
features: list[int] = [64, 128, 256, 512], bilinear: bool = True):
super().__init__()
self.in_conv = DoubleConv(in_channels, features[0])
# Encoder
self.downs = nn.ModuleList([
DownBlock(features[i], features[i+1])
for i in range(len(features) - 1)
])
# Bottleneck
factor = 2 if bilinear else 1
self.bottleneck = DownBlock(features[-1], features[-1] * 2 // factor)
# Decoder
self.ups = nn.ModuleList([
UpBlock(features[-1] * 2 // factor + features[-(i+1)],
features[-(i+2)] if i < len(features)-1 else features[0],
bilinear)
for i in range(len(features))
])
# Semplifichiamo con lista esplicita
self.ups = nn.ModuleList([
UpBlock(1024, 512 // factor, bilinear),
UpBlock(512, 256 // factor, bilinear),
UpBlock(256, 128 // factor, bilinear),
UpBlock(128, 64, bilinear),
])
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder path (salva skip connections)
x1 = self.in_conv(x)
x2 = self.downs[0](x1)
x3 = self.downs[1](x2)
x4 = self.downs[2](x3)
# Bottleneck
x5 = self.bottleneck(x4)
# Decoder path (usa skip connections)
x = self.ups[0](x5, x4)
x = self.ups[1](x, x3)
x = self.ups[2](x, x2)
x = self.ups[3](x, x1)
return self.out_conv(x)
# Test architettura
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(2, 3, 256, 256)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")
# Input: torch.Size([2, 3, 256, 256]) -> Output: torch.Size([2, 1, 256, 256])
total_params = sum(p.numel() for p in model.parameters())
print(f"Parametri: {total_params:,}")
2.2 Antrenamentul U-Net cu pierderea zarurilor
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice Loss per segmentazione binaria.
Gestisce naturalmente lo sbilanciamento di classe tipico delle immagini mediche
(es. 95% sfondo, 5% lesione).
"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Applica sigmoid per ottenere probabilità
preds = torch.sigmoid(pred_logits)
# Flatten per calcolo efficiente
preds_flat = preds.view(-1)
targets_flat = targets.view(-1)
intersection = (preds_flat * targets_flat).sum()
dice = (2.0 * intersection + self.smooth) / (
preds_flat.sum() + targets_flat.sum() + self.smooth
)
return 1.0 - dice # loss = 1 - Dice (minimizzare)
class CombinedLoss(nn.Module):
"""
Combinazione BCE + Dice: il compromesso migliore per segmentazione medica.
BCE: ottimizza ogni pixel individualmente
Dice: ottimizza l'overlap globale tra predizione e ground truth
"""
def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
self.bce_weight = bce_weight
self.dice_weight = dice_weight
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
bce_loss = self.bce(pred_logits, targets.float())
dice_loss = self.dice(pred_logits, targets.float())
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
def train_unet(
model: UNet,
train_loader,
val_loader,
num_epochs: int = 50,
learning_rate: float = 1e-4
) -> dict:
"""
Training completo di U-Net con:
- Combined BCE+Dice loss
- AdamW + CosineAnnealingLR
- Early stopping su Dice score di validazione
- Checkpoint del modello migliore
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_dice = 0.0
patience = 15
no_improve = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
loss = criterion(pred_logits, masks)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Validation
model.eval()
val_loss = 0.0
val_dice_scores = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
val_loss += criterion(pred_logits, masks).item()
preds = torch.sigmoid(pred_logits)
for p, m in zip(preds, masks):
val_dice_scores.append(dice_score(p, m))
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
avg_val_dice = sum(val_dice_scores) / len(val_dice_scores)
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
history['val_dice'].append(avg_val_dice)
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch+1:2d}/{num_epochs} | "
f"Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
f"Dice: {avg_val_dice:.4f} | Best: {best_dice:.4f}")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Training completato. Best Dice Score: {best_dice:.4f}")
return history
3. Segmentează orice model (SAM)
Meta AI a fost lansat SAM (Kirillov et al., 2023) cu scopul ambițios de construi un model de segmentare generalist: un model antrenat pe 1 miliard de măști care poate segmenta nimic in orice imagine folosind prompturi flexibile (faceți clic pe un punct, casetă de delimitare, text). SAM2 (2024) a extins modelul și la videoclipuri.
3.1 Arhitectura SAM
SAM este compus din trei componente principale:
- Codificator de imagine: Vision Transformer (ViT-H cu parametri 632M) care generează imagini dense încorporate. Se rulează o singură dată pe imagine.
- Prompt Encoder: Codificați solicitări de diferite tipuri (puncte, casete, măști, text) în înglobări compatibile cu decodor.
- Decodor de mască: Transformator ușor care combină încorporarea imaginii + solicitări pentru a genera măști. Generați 3 măști de candidați cu scoruri de încredere.
# pip install segment-anything
# Download checkpoint: https://github.com/facebookresearch/segment-anything
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
def load_sam_model(
model_type: str = 'vit_h',
checkpoint_path: str = 'sam_vit_h_4b8939.pth',
device: str = 'cuda'
):
"""
Carica il modello SAM.
Tipi disponibili: 'vit_h' (default, max accuratezza), 'vit_l', 'vit_b' (più veloce)
"""
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
return sam
def segment_with_point_prompt(
sam_model,
image: np.ndarray,
point_coords: list[tuple[int, int]],
point_labels: list[int] # 1=foreground, 0=background
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Segmenta con prompt a punti.
Restituisce: (maschere, score, logits) - 3 proposte ordinate per score.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True # genera 3 maschere candidate
)
# Ordina per score decrescente
sorted_idx = np.argsort(scores)[::-1]
return masks[sorted_idx], scores[sorted_idx], logits[sorted_idx]
def segment_with_box_prompt(
sam_model,
image: np.ndarray,
box: tuple[int, int, int, int] # [x1, y1, x2, y2]
) -> tuple[np.ndarray, float]:
"""
Segmenta con prompt bounding box.
Il box definisce la regione di interesse da segmentare.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=np.array([box]),
multimask_output=False # 1 sola maschera con box prompt
)
return masks[0], float(scores[0])
def automatic_segmentation(sam_model, image: np.ndarray) -> list[dict]:
"""
Segmentazione automatica: SAM segmenta TUTTO nell'immagine
senza nessun prompt. Usa una griglia di punti come seed.
"""
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32, # griglia 32x32 = 1024 punti seed
pred_iou_thresh=0.88, # filtra maschere con IoU basso
stability_score_thresh=0.95, # filtra maschere instabili
crop_n_layers=1, # multi-crop per oggetti piccoli
crop_n_points_downscale_factor=2,
min_mask_region_area=100 # rimuovi regioni molto piccole
)
masks = mask_generator.generate(image)
# Ordina per area decrescente
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
print(f"SAM ha trovato {len(masks)} segmenti")
for i, mask in enumerate(masks[:5]):
print(f" Segmento {i+1}: area={mask['area']} "
f"score={mask['predicted_iou']:.3f}")
return masks
def visualize_sam_results(image: np.ndarray, masks: list[dict],
alpha: float = 0.4) -> np.ndarray:
"""Visualizza tutte le maschere SAM con colori random."""
result = image.copy()
np.random.seed(42)
for mask_info in masks:
mask = mask_info['segmentation'] # bool array [H, W]
color = np.random.randint(50, 255, 3)
overlay = result.copy()
overlay[mask] = color
result = cv2.addWeighted(result, 1 - alpha, overlay, alpha, 0)
# Contorno
contours, _ = cv2.findContours(
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(result, contours, -1, color.tolist(), 2)
return result
# Esempio d'uso
sam = load_sam_model('vit_b', 'sam_vit_b_01ec64.pth') # versione più leggera
image = cv2.imread('image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segmenta con un click (punto foreground)
masks, scores, _ = segment_with_point_prompt(
sam, image_rgb,
point_coords=[(320, 240)], # click al centro dell'oggetto
point_labels=[1] # 1 = foreground
)
best_mask = masks[0]
print(f"Maschera trovata con score: {scores[0]:.3f}")
3.2 SAM2 pentru video
# pip install sam2
# SAM2 rilasciato da Meta AI nell'agosto 2024
import torch
from sam2.build_sam import build_sam2_video_predictor
def segment_video_with_sam2(
video_path: str,
initial_frame: int,
initial_points: list[tuple[int, int]],
checkpoint: str = 'sam2_hiera_large.pt',
config: str = 'sam2_hiera_l.yaml'
) -> dict[int, np.ndarray]:
"""
Segmenta e traccia un oggetto attraverso i frame di un video.
Inizializza con punti sul primo frame, poi traccia automaticamente.
Returns:
Dict frame_idx -> maschera binaria [H, W]
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
# Inizializza sul video
state = predictor.init_state(video_path=video_path)
predictor.reset_state(state)
# Aggiungi prompt sul frame iniziale
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
inference_state=state,
frame_idx=initial_frame,
obj_id=1,
points=np.array(initial_points),
labels=np.ones(len(initial_points), dtype=np.int32)
)
# Propaga su tutto il video
video_masks = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
mask = (masks[0][0] > 0.0).cpu().numpy()
video_masks[frame_idx] = mask
print(f"Segmentazione completata: {len(video_masks)} frame processati")
return video_masks
4. Studiu de caz: Segmentarea pulmonară din radiografii
Aplicăm U-Net segmentării plămânilor din radiografiile toracice folosind Setul de date cu raze X din comitatul Montgomery (138 radiografii cu măști de segmentare pulmonar adnotat manual de radiologi).
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
class LungXrayDataset(Dataset):
"""Dataset per segmentazione polmoni da radiografie (Montgomery CXR)."""
def __init__(self, image_dir: str, mask_dir: str, img_size: int = 512,
augment: bool = True):
self.image_paths = sorted(Path(image_dir).glob('*.png'))
self.mask_dir = Path(mask_dir)
self.img_size = img_size
if augment:
self.transform = A.Compose([
A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
rotate_limit=15, p=0.7),
A.OneOf([
A.GaussNoise(var_limit=(10, 50)),
A.GaussianBlur(blur_limit=3),
A.MedianBlur(blur_limit=3)
], p=0.3),
A.RandomBrightnessContrast(brightness_limit=0.2,
contrast_limit=0.2, p=0.5),
A.CLAHE(clip_limit=2, p=0.3), # Contrast Limited AHE per RX
A.Normalize(mean=[0.485], std=[0.229]), # Grayscale normalization
ToTensorV2()
])
else:
self.transform = A.Compose([
A.Resize(img_size, img_size),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2()
])
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path = self.image_paths[idx]
mask_path = self.mask_dir / img_path.name
# Carica immagine (grayscale)
image = np.array(Image.open(img_path).convert('L'))
mask = np.array(Image.open(mask_path).convert('L'))
# Binarizza maschera
mask = (mask > 127).astype(np.float32)
transformed = self.transform(image=image, mask=mask)
return transformed['image'], transformed['mask'].unsqueeze(0)
def run_lung_segmentation_pipeline():
"""Pipeline completa: dataset -> training -> valutazione -> salvataggio."""
# Data loading
train_dataset = LungXrayDataset(
'data/train/images', 'data/train/masks', augment=True
)
val_dataset = LungXrayDataset(
'data/val/images', 'data/val/masks', augment=False
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
num_workers=4, pin_memory=True)
# Modello: U-Net per immagini grayscale
model = UNet(in_channels=1, num_classes=1, features=[32, 64, 128, 256])
# Training
history = train_unet(model, train_loader, val_loader, num_epochs=100)
# Valutazione finale
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_unet.pth', map_location=device))
model.eval()
all_dice = []
all_iou = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
preds = torch.sigmoid(model(images))
for p, m in zip(preds, masks):
all_dice.append(dice_score(p, m))
all_iou.append(compute_iou(p, m))
print(f"\n=== Risultati Finali ===")
print(f"Dice Score: {np.mean(all_dice):.4f} ± {np.std(all_dice):.4f}")
print(f"IoU: {np.mean(all_iou):.4f} ± {np.std(all_iou):.4f}")
# Risultati attesi per U-Net su Montgomery: Dice ~0.97, IoU ~0.94
5. Segmentează Anything Model 2: Segmentarea video zero-shot
SAM2 (Meta AI, iulie 2024) a extins SAM la secvențe video: dincolo pentru a segmenta obiecte în imagini statice cu solicitări interactive (punct, casetă, mască), SAM2 propagă automat masca de-a lungul cadrelor video datorită unui modul memorie. Este primul model care realizează în mod fiabil segmentarea zero-shot pe video.
SAM vs SAM2: diferențe cheie
| Caracteristici | SAM (2023) | SAM2 (2024) |
|---|---|---|
| Suport video | Nu (doar imagini) | Da (propagare în timp) |
| Modul de memorie | Absent | Bancă de memorie cu atenție încrucișată |
| Tip prompt | Punct, casetă, mască, text (prin CLIP) | Point, Box, Mask (+ urmărire video) |
| Viteză | ~50 ms/imagine (ViT-H) | ~44ms/cadru (Hiera-L), ~8ms (Hiera-T) |
| Date de antrenament | SA-1B (măști 1B) | SA-V (50.9K videoclipuri, 642K măști) |
| Multi-Obiect | Limitat | Da, urmărire simultană a mai multor obiecte |
import torch
import numpy as np
import cv2
from PIL import Image
# pip install git+https://github.com/facebookresearch/segment-anything-2.git
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
# ============================================================
# PARTE 1: SAM2 su singola immagine
# ============================================================
def sam2_image_segment(image_path: str,
point_coords: list[list[int]],
point_labels: list[int], # 1=foreground, 0=background
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> np.ndarray:
"""
Segmentazione con SAM2 su singola immagine.
point_coords: [[x1, y1], [x2, y2], ...] - punti prompt
point_labels: [1, 1, 0, ...] - 1=foreground, 0=background
Returns: maschera binaria [H, W] bool
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
# Carica immagine
image = np.array(Image.open(image_path).convert('RGB'))
predictor.set_image(image)
# Predici maschera con prompt
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True, # 3 maschere con confidenze diverse
)
# Prendi la maschera con score più alto
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Maschera selezionata: score={scores[best_idx]:.3f}, "
f"area={best_mask.sum()} pixel")
return best_mask # [H, W] bool
def sam2_box_prompt(image_np: np.ndarray,
box: list[int],
predictor: SAM2ImagePredictor) -> np.ndarray:
"""
Segmentazione con prompt box (x1, y1, x2, y2).
Più preciso dei punti per oggetti con bordi definiti.
"""
predictor.set_image(image_np)
masks, scores, _ = predictor.predict(
box=np.array(box),
multimask_output=False, # Box prompt -> singola maschera ottimale
)
return masks[0] # [H, W] bool
# ============================================================
# PARTE 2: SAM2 su video - propagazione temporale
# ============================================================
def sam2_video_segment(video_dir: str,
frame_idx: int,
points: list[list[int]],
labels: list[int],
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> dict:
"""
SAM2 video predictor: segmenta un oggetto nel frame 'frame_idx'
e propaga la maschera automaticamente lungo tutto il video.
video_dir: cartella con frame del video (frame_*.jpg)
Returns: dict {frame_idx: {obj_id: mask}}
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
# Inizializza predictor con la directory video
inference_state = predictor.init_state(video_path=video_dir)
# Aggiungi prompt nel frame di annotazione
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=1, # ID oggetto da trackare
points=np.array(points, dtype=np.float32),
labels=np.array(labels, dtype=np.int32),
)
# Propaga la segmentazione su tutto il video
all_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
for obj_id, mask_logit in zip(out_obj_ids, out_mask_logits):
mask = (mask_logit > 0).squeeze().cpu().numpy()
if out_frame_idx not in all_masks:
all_masks[out_frame_idx] = {}
all_masks[out_frame_idx][int(obj_id)] = mask
return all_masks
# ============================================================
# PARTE 3: SAM2 come labeling tool automatizzato
# ============================================================
class SAM2AutoLabeler:
"""
Usa SAM2 per generare automaticamente maschere di training.
Riduce i costi di annotazione del 60-80% rispetto all'annotazione manuale.
Human-in-the-loop: un umano valida e corregge le predizioni SAM2.
"""
def __init__(self, checkpoint: str = 'sam2_hiera_base_plus.pt',
model_cfg: str = 'sam2_hiera_base_plus.yaml'):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
self.predictor = SAM2ImagePredictor(model)
def auto_label_from_yolo_boxes(self,
image_np: np.ndarray,
yolo_boxes: list[tuple],
min_score: float = 0.7) -> list[dict]:
"""
Genera maschere SAM2 usando bounding box di YOLO come prompt.
Workflow: YOLO rileva oggetti -> SAM2 affina con maschera pixel-perfect.
yolo_boxes: lista di (x1, y1, x2, y2, class_id, confidence)
Returns: lista di {box, class_id, mask, sam_score}
"""
self.predictor.set_image(image_np)
results = []
for x1, y1, x2, y2, class_id, conf in yolo_boxes:
if conf < 0.5:
continue
masks, scores, _ = self.predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=True,
)
best_idx = np.argmax(scores)
if scores[best_idx] < min_score:
continue
results.append({
'box': (x1, y1, x2, y2),
'class_id': class_id,
'mask': masks[best_idx],
'sam_score': float(scores[best_idx]),
'yolo_conf': float(conf)
})
return results
def save_masks_coco_format(self, results: list[dict],
image_id: int,
output_path: str) -> None:
"""Salva maschere in formato COCO per training Mask R-CNN."""
import json
from pycocotools import mask as coco_mask
annotations = []
for ann_id, r in enumerate(results):
binary_mask = r['mask'].astype(np.uint8)
rle = coco_mask.encode(np.asfortranarray(binary_mask))
rle['counts'] = rle['counts'].decode('utf-8')
area = float(np.sum(binary_mask))
x1, y1, x2, y2 = r['box']
annotations.append({
'id': ann_id,
'image_id': image_id,
'category_id': r['class_id'],
'segmentation': rle,
'area': area,
'bbox': [x1, y1, x2-x1, y2-y1],
'iscrowd': 0
})
with open(output_path, 'w') as f:
json.dump(annotations, f, indent=2)
6. Cele mai bune practici pentru segmentare
Recomandări cheie
- Alegerea pierderii: Pentru seturi de date neechilibrate (de exemplu, leziuni mici pe un fundal mare), utilizați Dice Loss sau Focal Loss în loc de BCE pur. ECB+Dice combinat este adesea cel mai bun compromis.
- Normalizare specifică domeniului: Pentru imagini medicale (scale de gri) utilizați statistici calculate pe setul de date specific, nu ImageNet. Pentru radiografii, preprocesarea CLAHE îmbunătățește semnificativ rezultatele.
- Mărirea conservatoare a datelor: În medicină, nu aplicați flip-uri verticale dacă nu are sens anatomic. Nu distorsionați prea mult: structurile anatomice au orientări precise.
- Rezoluție de intrare: U-Net este sensibil la rezoluție. Raze X: minim 512x512. Pentru detalii fine (histologie, citologie): 1024x1024 sau abordare prin crop.
- Post-procesare: Aplicați câmpuri aleatoare condiționale (CRF) sau operații morfologice (închidere, deschidere) pentru a ascuți marginile măștii.
- SAM pentru etichetare: Utilizați SAM pentru a accelera generarea de măști de antrenament (etichetare uman-in-the-loop), reducând costurile de adnotare cu 60-80%.
Greșeli comune
- Nu validați pe date de distribuție diferită: Modelele de segmentare medicală sunt notoriu fragile la schimbările de domeniu (scaner diferit, protocol, populație). Validați întotdeauna datele din diferite centre.
- Ignorați măștile de calitate scăzută: La antrenament, adnotările umane au variabilitate între observatori. Dacă este posibil, utilizați consensul cu mai mulți adnotatori sau pierderea în greutate pe baza încrederii adnotărilor.
- Folosește zarurile doar ca pierdere: Dice Loss este instabil cu loturi mici și are discontinuitate în gradient. Combinați întotdeauna cu BCE sau utilizați o variantă de pierdere generalizată a zarurilor.
- Neglijarea claselor rare: În segmentarea cu mai multe clase, clasele rare (câțiva pixeli) tind să fie ignorate de model. Utilizați pierderea ponderată în funcție de clasă sau supraeșantionarea imaginilor care conțin clase rare.
Concluzii
Am explorat principalele arhitecturi de segmentare și aplicațiile lor practice:
- U-Net: arhitectură codificator-decodor cu conexiuni skip, standard de facto pentru segmentarea medicală cu Dice ~0,97 pe radiografiile pulmonare
- Mască R-CNN: segmentare a instanțelor cu casetă de delimitare + mască pentru fiecare instanță, excelentă pentru scene naturale dense
- SAM și SAM2: Segmentare universală zero-shot cu solicitări interactive (SAM) și propagare video temporală (SAM2), revoluționară pentru etichetarea rapidă
- SAM2 ca instrument de etichetare automată: conducta YOLO+SAM2 care reduce costurile de adnotare cu 60-80%
- Dice Loss și combinat BCE+Dice: pierderile optime pentru seturi de date dezechilibrate cu regiuni mici
- Post-procesare: morfologie matematică și CRF pentru a rafina marginile măștii







