Sémantická a instanční segmentace: U-Net, maska R-CNN a SAM
Image segmentation represents the most granular level of visual understanding: instead of knowing "there is a tumor in this image" (classification) or "the tumor is found v této oblasti“ (detekce), chceme vědět přesně které pixely patří k nádoru. Tato dokonalá přesnost pixelů je základem v medicíně, robotické chirurgii a autonomním řízení a průmyslová kontrola kvality.
V tomto článku prozkoumáme nejdůležitější architektury pro segmentaci: U-Net (model, který způsobil revoluci v segmentaci medicíny), Maska R-CNN (zlatý standard například segmentace) e SAM (Segment Anything Model od Meta AI, která nově definovala hranice možného).
Co se naučíte
- Architektura U-Net: kodér-dekodér s přeskočením připojení pro lékařskou segmentaci
- Implementace U-Net od nuly v PyTorch se školením o lékařských datových sadách
- Maska R-CNN: segmentace instancí s ohraničujícími rámečky + binární masky
- Segment Anything Model (SAM): Segmentace typu zero-shot s vizuálními výzvami
- Hodnotící metriky: Skóre v kostce, IoU, Precision/Recall pro segmentaci
- Techniky následného zpracování: CRF, matematická morfologie
- Případová studie: segmentace plic z rentgenových snímků (open source dataset)
- Nasazení segmentačních modelů ve výrobě
1. Základy segmentace
1.1 Typy segmentace
Taxonomie segmentace
| Typ | Rozlišujte instance | Pozadí hodnocení | Výstupy | Architektury |
|---|---|---|---|---|
| Sémantika | No | Si | Mapa VxŠ s popisky na pixel | U-Net, DeepLabv3, SegFormer |
| Například | Si | Ne (jen "věci") | Binární maska pro objekt | Maska R-CNN, SOLOv2, YOLACT |
| Panoptikum | Ano (pro "věci") | Ano (pro "věci") | Jednotná mapa instance+sémantiky | Panoptic FPN, Mask2Former |
| Interaktivní | Ano (s výzvou) | Záleží na výzvě | Maska řízená kliknutím/bboxem | SAM, SAM2, ClickSEG |
1.2 Metriky hodnocení
Pro segmentaci se používají specifické metriky, které měří překrytí pixel po pixelu mezi předpovězenou maskou a základní pravdou:
import torch
import numpy as np
from typing import Union
def compute_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> float:
"""
Intersection over Union per segmentazione binaria.
pred, target: tensori [H, W] o [B, H, W] con valori in [0,1]
"""
pred_binary = (pred >= threshold).bool()
target_binary = target.bool()
intersection = (pred_binary & target_binary).float().sum()
union = (pred_binary | target_binary).float().sum()
if union == 0:
return 1.0 # caso degenere: entrambe vuote
return float(intersection / union)
def dice_score(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
smooth: float = 1.0) -> float:
"""
Dice Score (F1 per segmentazione): 2*|X intersect Y| / (|X| + |Y|)
Preferito in ambito medico perchè meno sensibile agli sbilanciamenti.
Valore: 0 (peggio) -> 1 (perfetto)
"""
pred_binary = (pred >= threshold).float()
target_binary = target.float()
intersection = (pred_binary * target_binary).sum()
dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target_binary.sum() + smooth)
return float(dice)
def compute_multiclass_miou(pred_logits: torch.Tensor, targets: torch.Tensor,
num_classes: int, ignore_index: int = 255) -> float:
"""
mIoU per segmentazione semantica multi-classe.
pred_logits: [B, C, H, W] - logit grezzi
targets: [B, H, W] - indici di classe 0..num_classes-1
"""
preds = pred_logits.argmax(dim=1) # [B, H, W]
ious = []
for cls in range(num_classes):
pred_cls = preds == cls
true_cls = targets == cls
valid = targets != ignore_index
pred_cls = pred_cls & valid
true_cls = true_cls & valid
intersection = (pred_cls & true_cls).sum().float()
union = (pred_cls | true_cls).sum().float()
if union > 0:
ious.append(float(intersection / union))
return float(np.mean(ious)) if ious else 0.0
def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
"""
Hausdorff Distance: misura la distanza massima tra i bordi delle maschere.
Utile in medicina per valutare la precisione dei contorni.
"""
from scipy.spatial.distance import directed_hausdorff
pred_points = np.argwhere(pred)
target_points = np.argwhere(target)
if len(pred_points) == 0 or len(target_points) == 0:
return float('inf')
d1 = directed_hausdorff(pred_points, target_points)[0]
d2 = directed_hausdorff(target_points, pred_points)[0]
return max(d1, d2)
print("Esempio metriche:")
pred = torch.sigmoid(torch.randn(256, 256))
target = (torch.randn(256, 256) > 0).float()
iou = compute_iou(pred, target)
dice = dice_score(pred, target)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
2. U-Net: Síť pro lékařskou segmentaci
U-Net (Ronneberger et al., 2015) byl původně navržen pro segmentaci biomedicínských snímků. Jeho architektura ve tvaru "U" s přeskočit spojení mezi kodérem a dekodérem a stala se dominantní šablonou pro jakýkoli úkol segmentace husté, od lékařských pixelů po satelitní mapy, od průmyslových snímků po venkovní scény.
2.1 Architektura U-Net
Architektura je rozdělena do tří částí:
- Kodér (cesta kontrakce): série konvolučních bloků + maximální sdružování, které snižují rozlišení a zvyšují kanály, extrahují sémanticky bohaté, ale prostorově nepřesné vlastnosti
- Úzké místo: nejhlubší blok, pracuje s nejnižším rozlišením
- Dekodér (cesta rozšíření): série upsampling + conv, které obnovují původní rozlišení, zřetězení map funkcí kodéru přes přeskočení připojení k obnovení prostorových detailů
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Blocco base U-Net: Conv-BN-ReLU-Conv-BN-ReLU."""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.double_conv(x)
class DownBlock(nn.Module):
"""Encoder block: MaxPool2d + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Decoder block: Upsample + concatenazione skip + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# Padding se le dimensioni non coincidono esattamente
diff_h = x2.size(2) - x1.size(2)
diff_w = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
diff_h // 2, diff_h - diff_h // 2])
# Skip connection: concatena feature encoder + decoder
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net originale per segmentazione binaria o multi-classe.
Architettura:
Input -> [64] -> [128] -> [256] -> [512] -> [1024] (bottleneck)
-> [512] -> [256] -> [128] -> [64] -> Output
Le frecce verso il basso sono encoder (+ maxpool)
Le frecce verso l'alto sono decoder (+ skip connections)
"""
def __init__(self, in_channels: int = 1, num_classes: int = 1,
features: list[int] = [64, 128, 256, 512], bilinear: bool = True):
super().__init__()
self.in_conv = DoubleConv(in_channels, features[0])
# Encoder
self.downs = nn.ModuleList([
DownBlock(features[i], features[i+1])
for i in range(len(features) - 1)
])
# Bottleneck
factor = 2 if bilinear else 1
self.bottleneck = DownBlock(features[-1], features[-1] * 2 // factor)
# Decoder
self.ups = nn.ModuleList([
UpBlock(features[-1] * 2 // factor + features[-(i+1)],
features[-(i+2)] if i < len(features)-1 else features[0],
bilinear)
for i in range(len(features))
])
# Semplifichiamo con lista esplicita
self.ups = nn.ModuleList([
UpBlock(1024, 512 // factor, bilinear),
UpBlock(512, 256 // factor, bilinear),
UpBlock(256, 128 // factor, bilinear),
UpBlock(128, 64, bilinear),
])
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder path (salva skip connections)
x1 = self.in_conv(x)
x2 = self.downs[0](x1)
x3 = self.downs[1](x2)
x4 = self.downs[2](x3)
# Bottleneck
x5 = self.bottleneck(x4)
# Decoder path (usa skip connections)
x = self.ups[0](x5, x4)
x = self.ups[1](x, x3)
x = self.ups[2](x, x2)
x = self.ups[3](x, x1)
return self.out_conv(x)
# Test architettura
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(2, 3, 256, 256)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")
# Input: torch.Size([2, 3, 256, 256]) -> Output: torch.Size([2, 1, 256, 256])
total_params = sum(p.numel() for p in model.parameters())
print(f"Parametri: {total_params:,}")
2.2 Trénink U-Net se ztrátou kostek
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice Loss per segmentazione binaria.
Gestisce naturalmente lo sbilanciamento di classe tipico delle immagini mediche
(es. 95% sfondo, 5% lesione).
"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Applica sigmoid per ottenere probabilità
preds = torch.sigmoid(pred_logits)
# Flatten per calcolo efficiente
preds_flat = preds.view(-1)
targets_flat = targets.view(-1)
intersection = (preds_flat * targets_flat).sum()
dice = (2.0 * intersection + self.smooth) / (
preds_flat.sum() + targets_flat.sum() + self.smooth
)
return 1.0 - dice # loss = 1 - Dice (minimizzare)
class CombinedLoss(nn.Module):
"""
Combinazione BCE + Dice: il compromesso migliore per segmentazione medica.
BCE: ottimizza ogni pixel individualmente
Dice: ottimizza l'overlap globale tra predizione e ground truth
"""
def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
self.bce_weight = bce_weight
self.dice_weight = dice_weight
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
bce_loss = self.bce(pred_logits, targets.float())
dice_loss = self.dice(pred_logits, targets.float())
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
def train_unet(
model: UNet,
train_loader,
val_loader,
num_epochs: int = 50,
learning_rate: float = 1e-4
) -> dict:
"""
Training completo di U-Net con:
- Combined BCE+Dice loss
- AdamW + CosineAnnealingLR
- Early stopping su Dice score di validazione
- Checkpoint del modello migliore
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_dice = 0.0
patience = 15
no_improve = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
loss = criterion(pred_logits, masks)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Validation
model.eval()
val_loss = 0.0
val_dice_scores = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
val_loss += criterion(pred_logits, masks).item()
preds = torch.sigmoid(pred_logits)
for p, m in zip(preds, masks):
val_dice_scores.append(dice_score(p, m))
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
avg_val_dice = sum(val_dice_scores) / len(val_dice_scores)
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
history['val_dice'].append(avg_val_dice)
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch+1:2d}/{num_epochs} | "
f"Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
f"Dice: {avg_val_dice:.4f} | Best: {best_dice:.4f}")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Training completato. Best Dice Score: {best_dice:.4f}")
return history
3. Segmentace modelu Anything Model (SAM)
Meta AI byla vydána SAM (Kirillov et al., 2023) s ambiciózním cílem vybudovat obecný segmentační model: model trénovaný na 1 miliardě masek který umí segmentovat nic in jakýkoli obrázek pomocí flexibilních výzev (klikněte na bod, ohraničovací rámeček, text). SAM2 (2024) rozšířil model i na videa.
3.1 Architektura SAM
SAM se skládá ze tří hlavních komponent:
- Kodér obrázků: Vision Transformer (ViT-H s parametry 632M), který generuje husté vkládání obrazu. Spustí se pouze jednou na obrázek.
- Prompt Encoder: Zakódujte výzvy různých typů (body, rámečky, masky, text) do vložení kompatibilních s dekodérem.
- Dekodér masky: Lehký transformátor, který kombinuje vkládání obrázků + výzvy ke generování masek. Vygenerujte 3 kandidátní masky se skóre spolehlivosti.
# pip install segment-anything
# Download checkpoint: https://github.com/facebookresearch/segment-anything
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
def load_sam_model(
model_type: str = 'vit_h',
checkpoint_path: str = 'sam_vit_h_4b8939.pth',
device: str = 'cuda'
):
"""
Carica il modello SAM.
Tipi disponibili: 'vit_h' (default, max accuratezza), 'vit_l', 'vit_b' (più veloce)
"""
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
return sam
def segment_with_point_prompt(
sam_model,
image: np.ndarray,
point_coords: list[tuple[int, int]],
point_labels: list[int] # 1=foreground, 0=background
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Segmenta con prompt a punti.
Restituisce: (maschere, score, logits) - 3 proposte ordinate per score.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True # genera 3 maschere candidate
)
# Ordina per score decrescente
sorted_idx = np.argsort(scores)[::-1]
return masks[sorted_idx], scores[sorted_idx], logits[sorted_idx]
def segment_with_box_prompt(
sam_model,
image: np.ndarray,
box: tuple[int, int, int, int] # [x1, y1, x2, y2]
) -> tuple[np.ndarray, float]:
"""
Segmenta con prompt bounding box.
Il box definisce la regione di interesse da segmentare.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=np.array([box]),
multimask_output=False # 1 sola maschera con box prompt
)
return masks[0], float(scores[0])
def automatic_segmentation(sam_model, image: np.ndarray) -> list[dict]:
"""
Segmentazione automatica: SAM segmenta TUTTO nell'immagine
senza nessun prompt. Usa una griglia di punti come seed.
"""
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32, # griglia 32x32 = 1024 punti seed
pred_iou_thresh=0.88, # filtra maschere con IoU basso
stability_score_thresh=0.95, # filtra maschere instabili
crop_n_layers=1, # multi-crop per oggetti piccoli
crop_n_points_downscale_factor=2,
min_mask_region_area=100 # rimuovi regioni molto piccole
)
masks = mask_generator.generate(image)
# Ordina per area decrescente
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
print(f"SAM ha trovato {len(masks)} segmenti")
for i, mask in enumerate(masks[:5]):
print(f" Segmento {i+1}: area={mask['area']} "
f"score={mask['predicted_iou']:.3f}")
return masks
def visualize_sam_results(image: np.ndarray, masks: list[dict],
alpha: float = 0.4) -> np.ndarray:
"""Visualizza tutte le maschere SAM con colori random."""
result = image.copy()
np.random.seed(42)
for mask_info in masks:
mask = mask_info['segmentation'] # bool array [H, W]
color = np.random.randint(50, 255, 3)
overlay = result.copy()
overlay[mask] = color
result = cv2.addWeighted(result, 1 - alpha, overlay, alpha, 0)
# Contorno
contours, _ = cv2.findContours(
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(result, contours, -1, color.tolist(), 2)
return result
# Esempio d'uso
sam = load_sam_model('vit_b', 'sam_vit_b_01ec64.pth') # versione più leggera
image = cv2.imread('image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segmenta con un click (punto foreground)
masks, scores, _ = segment_with_point_prompt(
sam, image_rgb,
point_coords=[(320, 240)], # click al centro dell'oggetto
point_labels=[1] # 1 = foreground
)
best_mask = masks[0]
print(f"Maschera trovata con score: {scores[0]:.3f}")
3.2 SAM2 pro video
# pip install sam2
# SAM2 rilasciato da Meta AI nell'agosto 2024
import torch
from sam2.build_sam import build_sam2_video_predictor
def segment_video_with_sam2(
video_path: str,
initial_frame: int,
initial_points: list[tuple[int, int]],
checkpoint: str = 'sam2_hiera_large.pt',
config: str = 'sam2_hiera_l.yaml'
) -> dict[int, np.ndarray]:
"""
Segmenta e traccia un oggetto attraverso i frame di un video.
Inizializza con punti sul primo frame, poi traccia automaticamente.
Returns:
Dict frame_idx -> maschera binaria [H, W]
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
# Inizializza sul video
state = predictor.init_state(video_path=video_path)
predictor.reset_state(state)
# Aggiungi prompt sul frame iniziale
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
inference_state=state,
frame_idx=initial_frame,
obj_id=1,
points=np.array(initial_points),
labels=np.ones(len(initial_points), dtype=np.int32)
)
# Propaga su tutto il video
video_masks = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
mask = (masks[0][0] > 0.0).cpu().numpy()
video_masks[frame_idx] = mask
print(f"Segmentazione completata: {len(video_masks)} frame processati")
return video_masks
4. Případová studie: Segmentace plic z rentgenových snímků
Aplikujeme U-Net na segmentaci plic z rentgenových snímků hrudníku pomocí Rentgenová datová sada okresu Montgomery (138 rentgenových snímků se segmentačními maskami plicní anotované ručně radiology).
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
class LungXrayDataset(Dataset):
"""Dataset per segmentazione polmoni da radiografie (Montgomery CXR)."""
def __init__(self, image_dir: str, mask_dir: str, img_size: int = 512,
augment: bool = True):
self.image_paths = sorted(Path(image_dir).glob('*.png'))
self.mask_dir = Path(mask_dir)
self.img_size = img_size
if augment:
self.transform = A.Compose([
A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
rotate_limit=15, p=0.7),
A.OneOf([
A.GaussNoise(var_limit=(10, 50)),
A.GaussianBlur(blur_limit=3),
A.MedianBlur(blur_limit=3)
], p=0.3),
A.RandomBrightnessContrast(brightness_limit=0.2,
contrast_limit=0.2, p=0.5),
A.CLAHE(clip_limit=2, p=0.3), # Contrast Limited AHE per RX
A.Normalize(mean=[0.485], std=[0.229]), # Grayscale normalization
ToTensorV2()
])
else:
self.transform = A.Compose([
A.Resize(img_size, img_size),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2()
])
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path = self.image_paths[idx]
mask_path = self.mask_dir / img_path.name
# Carica immagine (grayscale)
image = np.array(Image.open(img_path).convert('L'))
mask = np.array(Image.open(mask_path).convert('L'))
# Binarizza maschera
mask = (mask > 127).astype(np.float32)
transformed = self.transform(image=image, mask=mask)
return transformed['image'], transformed['mask'].unsqueeze(0)
def run_lung_segmentation_pipeline():
"""Pipeline completa: dataset -> training -> valutazione -> salvataggio."""
# Data loading
train_dataset = LungXrayDataset(
'data/train/images', 'data/train/masks', augment=True
)
val_dataset = LungXrayDataset(
'data/val/images', 'data/val/masks', augment=False
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
num_workers=4, pin_memory=True)
# Modello: U-Net per immagini grayscale
model = UNet(in_channels=1, num_classes=1, features=[32, 64, 128, 256])
# Training
history = train_unet(model, train_loader, val_loader, num_epochs=100)
# Valutazione finale
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_unet.pth', map_location=device))
model.eval()
all_dice = []
all_iou = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
preds = torch.sigmoid(model(images))
for p, m in zip(preds, masks):
all_dice.append(dice_score(p, m))
all_iou.append(compute_iou(p, m))
print(f"\n=== Risultati Finali ===")
print(f"Dice Score: {np.mean(all_dice):.4f} ± {np.std(all_dice):.4f}")
print(f"IoU: {np.mean(all_iou):.4f} ± {np.std(all_iou):.4f}")
# Risultati attesi per U-Net su Montgomery: Dice ~0.97, IoU ~0.94
5. Segment Anything Model 2: Segmentace videa Zero-Shot
SAM2 (Meta AI, červenec 2024) rozšířila SAM na videosekvence: dále segmentovat objekty ve statických obrázcích pomocí interaktivních výzev (bod, rámeček, maska), SAM2 automaticky šíří masku podél video snímků díky modulu paměť. Je to první model, který na videu spolehlivě provádí segmentaci zero-shot.
SAM vs SAM2: Klíčové rozdíly
| Vlastnosti | SAM (2023) | SAM2 (2024) |
|---|---|---|
| Podpora videa | Ne (pouze obrázky) | Ano (šíření času) |
| Paměťový modul | Chybí | Paměťová banka s křížovou pozorností |
| Typ výzvy | Bod, pole, maska, text (přes CLIP) | Bod, krabice, maska (+ sledování videa) |
| Rychlost | ~50 ms/snímek (ViT-H) | ~44 ms/snímek (Hiera-L), ~8 ms (Hiera-T) |
| Údaje o školení | SA-1B (1B masky) | SA-V (50,9K videa, 642K masky) |
| Multi-Object | Omezený | Ano, simultánní sledování více objektů |
import torch
import numpy as np
import cv2
from PIL import Image
# pip install git+https://github.com/facebookresearch/segment-anything-2.git
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
# ============================================================
# PARTE 1: SAM2 su singola immagine
# ============================================================
def sam2_image_segment(image_path: str,
point_coords: list[list[int]],
point_labels: list[int], # 1=foreground, 0=background
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> np.ndarray:
"""
Segmentazione con SAM2 su singola immagine.
point_coords: [[x1, y1], [x2, y2], ...] - punti prompt
point_labels: [1, 1, 0, ...] - 1=foreground, 0=background
Returns: maschera binaria [H, W] bool
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
# Carica immagine
image = np.array(Image.open(image_path).convert('RGB'))
predictor.set_image(image)
# Predici maschera con prompt
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True, # 3 maschere con confidenze diverse
)
# Prendi la maschera con score più alto
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Maschera selezionata: score={scores[best_idx]:.3f}, "
f"area={best_mask.sum()} pixel")
return best_mask # [H, W] bool
def sam2_box_prompt(image_np: np.ndarray,
box: list[int],
predictor: SAM2ImagePredictor) -> np.ndarray:
"""
Segmentazione con prompt box (x1, y1, x2, y2).
Più preciso dei punti per oggetti con bordi definiti.
"""
predictor.set_image(image_np)
masks, scores, _ = predictor.predict(
box=np.array(box),
multimask_output=False, # Box prompt -> singola maschera ottimale
)
return masks[0] # [H, W] bool
# ============================================================
# PARTE 2: SAM2 su video - propagazione temporale
# ============================================================
def sam2_video_segment(video_dir: str,
frame_idx: int,
points: list[list[int]],
labels: list[int],
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> dict:
"""
SAM2 video predictor: segmenta un oggetto nel frame 'frame_idx'
e propaga la maschera automaticamente lungo tutto il video.
video_dir: cartella con frame del video (frame_*.jpg)
Returns: dict {frame_idx: {obj_id: mask}}
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
# Inizializza predictor con la directory video
inference_state = predictor.init_state(video_path=video_dir)
# Aggiungi prompt nel frame di annotazione
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=1, # ID oggetto da trackare
points=np.array(points, dtype=np.float32),
labels=np.array(labels, dtype=np.int32),
)
# Propaga la segmentazione su tutto il video
all_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
for obj_id, mask_logit in zip(out_obj_ids, out_mask_logits):
mask = (mask_logit > 0).squeeze().cpu().numpy()
if out_frame_idx not in all_masks:
all_masks[out_frame_idx] = {}
all_masks[out_frame_idx][int(obj_id)] = mask
return all_masks
# ============================================================
# PARTE 3: SAM2 come labeling tool automatizzato
# ============================================================
class SAM2AutoLabeler:
"""
Usa SAM2 per generare automaticamente maschere di training.
Riduce i costi di annotazione del 60-80% rispetto all'annotazione manuale.
Human-in-the-loop: un umano valida e corregge le predizioni SAM2.
"""
def __init__(self, checkpoint: str = 'sam2_hiera_base_plus.pt',
model_cfg: str = 'sam2_hiera_base_plus.yaml'):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
self.predictor = SAM2ImagePredictor(model)
def auto_label_from_yolo_boxes(self,
image_np: np.ndarray,
yolo_boxes: list[tuple],
min_score: float = 0.7) -> list[dict]:
"""
Genera maschere SAM2 usando bounding box di YOLO come prompt.
Workflow: YOLO rileva oggetti -> SAM2 affina con maschera pixel-perfect.
yolo_boxes: lista di (x1, y1, x2, y2, class_id, confidence)
Returns: lista di {box, class_id, mask, sam_score}
"""
self.predictor.set_image(image_np)
results = []
for x1, y1, x2, y2, class_id, conf in yolo_boxes:
if conf < 0.5:
continue
masks, scores, _ = self.predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=True,
)
best_idx = np.argmax(scores)
if scores[best_idx] < min_score:
continue
results.append({
'box': (x1, y1, x2, y2),
'class_id': class_id,
'mask': masks[best_idx],
'sam_score': float(scores[best_idx]),
'yolo_conf': float(conf)
})
return results
def save_masks_coco_format(self, results: list[dict],
image_id: int,
output_path: str) -> None:
"""Salva maschere in formato COCO per training Mask R-CNN."""
import json
from pycocotools import mask as coco_mask
annotations = []
for ann_id, r in enumerate(results):
binary_mask = r['mask'].astype(np.uint8)
rle = coco_mask.encode(np.asfortranarray(binary_mask))
rle['counts'] = rle['counts'].decode('utf-8')
area = float(np.sum(binary_mask))
x1, y1, x2, y2 = r['box']
annotations.append({
'id': ann_id,
'image_id': image_id,
'category_id': r['class_id'],
'segmentation': rle,
'area': area,
'bbox': [x1, y1, x2-x1, y2-y1],
'iscrowd': 0
})
with open(output_path, 'w') as f:
json.dump(annotations, f, indent=2)
6. Nejlepší postupy pro segmentaci
Klíčová doporučení
- Volba ztráty: Pro nevyvážené soubory dat (např. malé léze na velkém pozadí) použijte místo čistého BCE ztrátu kostek nebo fokální ztrátu. Kombinace ECB+kostky je často tím nejlepším kompromisem.
- Normalizace specifická pro doménu: Pro lékařské snímky (stupně šedi) použijte statistiku vypočítanou pro konkrétní datovou sadu, nikoli ImageNet. U rentgenových snímků předzpracování CLAHE výrazně zlepšuje výsledky.
- Konzervativní rozšíření dat: V medicíně neaplikujte vertikální flipy, pokud to anatomicky nedává smysl. Nepřekrucujte příliš: anatomické struktury mají přesné orientace.
- Vstupní rozlišení: U-Net je citlivý na rozlišení. Rentgenové záření: minimálně 512x512. Pro jemné detaily (histologie, cytologie): 1024x1024 nebo oříznutí.
- Následné zpracování: Použijte podmíněná náhodná pole (CRF) nebo morfologické operace (zavření, otevření) k zaostření okrajů masky.
- SAM pro označení: Použijte SAM k urychlení generování cvičných masek (označení human-in-the-loop), čímž se sníží náklady na anotace o 60–80 %.
Časté chyby
- Neověřujte data různé distribuce: Lékařské segmentační modely jsou notoricky křehké vůči doménovým posunům (jiný skener, protokol, populace). Vždy ověřujte data z různých center.
- Ignorujte masky nízké kvality: Při tréninku mají lidské anotace variabilitu mezi pozorovateli. Pokud je to možné, použijte konsensus s více anotátory nebo snížení hmotnosti na základě spolehlivosti anotací.
- Kostky používejte pouze jako ztrátu: Ztráta kostek je nestabilní u malých dávek a má nespojitost v gradientu. Vždy kombinujte s BCE nebo použijte variantu Generalized Dice Loss.
- Zanedbání vzácných tříd: Při vícetřídní segmentaci bývají vzácné třídy (několik pixelů) modelem ignorovány. Použijte třídu váženou ztrátu nebo převzorkování obrázků obsahujících vzácné třídy.
Závěry
Prozkoumali jsme hlavní segmentační architektury a jejich praktické aplikace:
- U-Net: architektura kodéru a dekodéru se skip připojeními, de facto standard pro lékařskou segmentaci s kostkami ~0,97 na rentgenových snímcích plic
- Maska R-CNN: segmentace instancí s ohraničujícím rámečkem + maska pro každou instanci, skvělé pro husté přírodní scény
- SAM a SAM2: Univerzální segmentace zero-shot s interaktivními výzvami (SAM) a dočasným šířením videa (SAM2), revoluční pro rychlé označování
- SAM2 jako nástroj pro automatické označování: kanál YOLO+SAM2, který snižuje náklady na poznámky o 60–80 %
- Ztráta kostek a kombinované BCE+kostky: optimální ztráty pro nevyvážené datové sady s malými oblastmi
- Následné zpracování: matematická morfologie a CRF pro zjemnění okrajů masky







