Anlamsal ve Örnek Segmentasyonu: U-Net, Mask R-CNN ve SAM
Görüntü segmentasyonu, görsel anlayışın en ayrıntılı düzeyini temsil eder: "Bu görüntüde bir tümör var" (sınıflandırma) veya "tümör bulundu" bilmek yerine bu alanda" (tespit) bilmek istiyoruz tam olarak hangi piksellerin tümöre ait olduğu. Bu mükemmel piksel hassasiyeti tıpta, robotik cerrahide ve otonom sürüşte temeldir ve endüstriyel kalite kontrolü.
Bu makalede segmentasyon için en önemli mimarileri inceleyeceğiz: U-Net (tıbbi segmentasyonda devrim yaratan model), Maske R-CNN (örneğin segmentasyon için altın standart) e SAM (Mümkün olanın sınırlarını yeniden tanımlayan Meta AI ile Her Şeyi Segmente Çıkarma Modeli).
Ne Öğreneceksiniz
- U-Net mimarisi: tıbbi segmentasyon için atlama bağlantılarına sahip kodlayıcı-kod çözücü
- Tıbbi veri kümeleri eğitimi ile PyTorch'ta sıfırdan U-Net uygulaması
- Maske R-CNN: sınırlayıcı kutular + ikili maskelerle örnek segmentasyonu
- Her Şeyi Segmente Ayırma Modeli (SAM): görsel komutlarla sıfır atışlı segmentasyon
- Değerlendirme metrikleri: Zar Puanı, IoU, Segmentasyon için Hassasiyet/Geri Çağırma
- İşlem sonrası teknikleri: CRF, matematiksel morfoloji
- Vaka çalışması: radyografilerden akciğer segmentasyonu (açık kaynak veri seti)
- Segmentasyon modellerinin üretimde devreye alınması
1. Segmentasyonun Temelleri
1.1 Segmentasyon Türleri
Segmentasyon Taksonomisi
| Tip | Örnekleri Ayırt Edin | Sıralama Arka Planı | Çıkışlar | Mimariler |
|---|---|---|---|---|
| Anlambilim | No | Si | Piksel başına etiket içeren HxW haritası | U-Net, DeepLabv3, SegFormer |
| Örneğin | Si | Hayır (sadece "şeyler") | Nesne için ikili maske | Maske R-CNN, SOLOv2, YOLACT |
| Panoptik | Evet ("şeyler" için) | Evet ("eşyalar" için) | Birleşik örnek+anlambilim haritası | Panoptik FPN, Mask2Former |
| İnteraktif | Evet (istemle) | Bu istemine bağlıdır | Tıklama/bbox odaklı maske | SAM, SAM2, ClickSEG |
1.2 Değerlendirme Metrikleri
Segmentasyon için örtüşen pikselleri piksel piksel ölçen belirli metrikler kullanılır tahmin edilen maske ile temel gerçek arasında:
import torch
import numpy as np
from typing import Union
def compute_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> float:
"""
Intersection over Union per segmentazione binaria.
pred, target: tensori [H, W] o [B, H, W] con valori in [0,1]
"""
pred_binary = (pred >= threshold).bool()
target_binary = target.bool()
intersection = (pred_binary & target_binary).float().sum()
union = (pred_binary | target_binary).float().sum()
if union == 0:
return 1.0 # caso degenere: entrambe vuote
return float(intersection / union)
def dice_score(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
smooth: float = 1.0) -> float:
"""
Dice Score (F1 per segmentazione): 2*|X intersect Y| / (|X| + |Y|)
Preferito in ambito medico perchè meno sensibile agli sbilanciamenti.
Valore: 0 (peggio) -> 1 (perfetto)
"""
pred_binary = (pred >= threshold).float()
target_binary = target.float()
intersection = (pred_binary * target_binary).sum()
dice = (2.0 * intersection + smooth) / (pred_binary.sum() + target_binary.sum() + smooth)
return float(dice)
def compute_multiclass_miou(pred_logits: torch.Tensor, targets: torch.Tensor,
num_classes: int, ignore_index: int = 255) -> float:
"""
mIoU per segmentazione semantica multi-classe.
pred_logits: [B, C, H, W] - logit grezzi
targets: [B, H, W] - indici di classe 0..num_classes-1
"""
preds = pred_logits.argmax(dim=1) # [B, H, W]
ious = []
for cls in range(num_classes):
pred_cls = preds == cls
true_cls = targets == cls
valid = targets != ignore_index
pred_cls = pred_cls & valid
true_cls = true_cls & valid
intersection = (pred_cls & true_cls).sum().float()
union = (pred_cls | true_cls).sum().float()
if union > 0:
ious.append(float(intersection / union))
return float(np.mean(ious)) if ious else 0.0
def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
"""
Hausdorff Distance: misura la distanza massima tra i bordi delle maschere.
Utile in medicina per valutare la precisione dei contorni.
"""
from scipy.spatial.distance import directed_hausdorff
pred_points = np.argwhere(pred)
target_points = np.argwhere(target)
if len(pred_points) == 0 or len(target_points) == 0:
return float('inf')
d1 = directed_hausdorff(pred_points, target_points)[0]
d2 = directed_hausdorff(target_points, pred_points)[0]
return max(d1, d2)
print("Esempio metriche:")
pred = torch.sigmoid(torch.randn(256, 256))
target = (torch.randn(256, 256) > 0).float()
iou = compute_iou(pred, target)
dice = dice_score(pred, target)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
2. U-Net: Tıbbi Segmentasyon Ağı
U-Net (Ronneberger ve diğerleri, 2015) başlangıçta segmentasyon için önerildi biyomedikal görsellerden oluşan bir koleksiyon. "U" şeklindeki mimarisi ile bağlantıları atla Kodlayıcı ve kod çözücü arasındadır ve her türlü segmentasyon görevi için baskın şablon haline gelmiştir tıbbi piksellerden uydu haritalarına, endüstriyel görüntülerden dış mekan sahnelerine kadar yoğun bir şekilde yer alıyor.
2.1 U-Net Mimarisi
Mimari üç bölüme ayrılmıştır:
- Kodlayıcı (daralma yolu): çözünürlüğü azaltan ve kanalları artıran, semantik olarak zengin ancak mekansal olarak kesin olmayan özellikler çıkaran bir dizi evrişimli blok + maksimum havuzlama
- Darboğaz: en derin blok, en düşük çözünürlükte çalışır
- Kod çözücü (genişletme yolu): Orijinal çözünürlüğü geri yükleyen bir dizi üst örnekleme + dönüşüm, uzamsal ayrıntıları kurtarmak için kodlayıcı özellik haritalarını atlama bağlantıları aracılığıyla birleştiriyor
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Blocco base U-Net: Conv-BN-ReLU-Conv-BN-ReLU."""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.double_conv(x)
class DownBlock(nn.Module):
"""Encoder block: MaxPool2d + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.maxpool_conv(x)
class UpBlock(nn.Module):
"""Decoder block: Upsample + concatenazione skip + DoubleConv."""
def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# Padding se le dimensioni non coincidono esattamente
diff_h = x2.size(2) - x1.size(2)
diff_w = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
diff_h // 2, diff_h - diff_h // 2])
# Skip connection: concatena feature encoder + decoder
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class UNet(nn.Module):
"""
U-Net originale per segmentazione binaria o multi-classe.
Architettura:
Input -> [64] -> [128] -> [256] -> [512] -> [1024] (bottleneck)
-> [512] -> [256] -> [128] -> [64] -> Output
Le frecce verso il basso sono encoder (+ maxpool)
Le frecce verso l'alto sono decoder (+ skip connections)
"""
def __init__(self, in_channels: int = 1, num_classes: int = 1,
features: list[int] = [64, 128, 256, 512], bilinear: bool = True):
super().__init__()
self.in_conv = DoubleConv(in_channels, features[0])
# Encoder
self.downs = nn.ModuleList([
DownBlock(features[i], features[i+1])
for i in range(len(features) - 1)
])
# Bottleneck
factor = 2 if bilinear else 1
self.bottleneck = DownBlock(features[-1], features[-1] * 2 // factor)
# Decoder
self.ups = nn.ModuleList([
UpBlock(features[-1] * 2 // factor + features[-(i+1)],
features[-(i+2)] if i < len(features)-1 else features[0],
bilinear)
for i in range(len(features))
])
# Semplifichiamo con lista esplicita
self.ups = nn.ModuleList([
UpBlock(1024, 512 // factor, bilinear),
UpBlock(512, 256 // factor, bilinear),
UpBlock(256, 128 // factor, bilinear),
UpBlock(128, 64, bilinear),
])
self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Encoder path (salva skip connections)
x1 = self.in_conv(x)
x2 = self.downs[0](x1)
x3 = self.downs[1](x2)
x4 = self.downs[2](x3)
# Bottleneck
x5 = self.bottleneck(x4)
# Decoder path (usa skip connections)
x = self.ups[0](x5, x4)
x = self.ups[1](x, x3)
x = self.ups[2](x, x2)
x = self.ups[3](x, x1)
return self.out_conv(x)
# Test architettura
model = UNet(in_channels=3, num_classes=1)
x = torch.randn(2, 3, 256, 256)
y = model(x)
print(f"Input: {x.shape} -> Output: {y.shape}")
# Input: torch.Size([2, 3, 256, 256]) -> Output: torch.Size([2, 1, 256, 256])
total_params = sum(p.numel() for p in model.parameters())
print(f"Parametri: {total_params:,}")
2.2 Zar Kaybı ile U-Net Eğitimi
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice Loss per segmentazione binaria.
Gestisce naturalmente lo sbilanciamento di classe tipico delle immagini mediche
(es. 95% sfondo, 5% lesione).
"""
def __init__(self, smooth: float = 1.0):
super().__init__()
self.smooth = smooth
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Applica sigmoid per ottenere probabilità
preds = torch.sigmoid(pred_logits)
# Flatten per calcolo efficiente
preds_flat = preds.view(-1)
targets_flat = targets.view(-1)
intersection = (preds_flat * targets_flat).sum()
dice = (2.0 * intersection + self.smooth) / (
preds_flat.sum() + targets_flat.sum() + self.smooth
)
return 1.0 - dice # loss = 1 - Dice (minimizzare)
class CombinedLoss(nn.Module):
"""
Combinazione BCE + Dice: il compromesso migliore per segmentazione medica.
BCE: ottimizza ogni pixel individualmente
Dice: ottimizza l'overlap globale tra predizione e ground truth
"""
def __init__(self, bce_weight: float = 0.5, dice_weight: float = 0.5):
super().__init__()
self.bce = nn.BCEWithLogitsLoss()
self.dice = DiceLoss()
self.bce_weight = bce_weight
self.dice_weight = dice_weight
def forward(self, pred_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
bce_loss = self.bce(pred_logits, targets.float())
dice_loss = self.dice(pred_logits, targets.float())
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
def train_unet(
model: UNet,
train_loader,
val_loader,
num_epochs: int = 50,
learning_rate: float = 1e-4
) -> dict:
"""
Training completo di U-Net con:
- Combined BCE+Dice loss
- AdamW + CosineAnnealingLR
- Early stopping su Dice score di validazione
- Checkpoint del modello migliore
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = CombinedLoss(bce_weight=0.5, dice_weight=0.5)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_epochs, eta_min=1e-6
)
history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
best_dice = 0.0
patience = 15
no_improve = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
loss = criterion(pred_logits, masks)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
train_loss += loss.item()
scheduler.step()
# Validation
model.eval()
val_loss = 0.0
val_dice_scores = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
pred_logits = model(images)
val_loss += criterion(pred_logits, masks).item()
preds = torch.sigmoid(pred_logits)
for p, m in zip(preds, masks):
val_dice_scores.append(dice_score(p, m))
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
avg_val_dice = sum(val_dice_scores) / len(val_dice_scores)
history['train_loss'].append(avg_train_loss)
history['val_loss'].append(avg_val_loss)
history['val_dice'].append(avg_val_dice)
if avg_val_dice > best_dice:
best_dice = avg_val_dice
torch.save(model.state_dict(), 'best_unet.pth')
no_improve = 0
else:
no_improve += 1
print(f"Epoch {epoch+1:2d}/{num_epochs} | "
f"Loss: {avg_train_loss:.4f}/{avg_val_loss:.4f} | "
f"Dice: {avg_val_dice:.4f} | Best: {best_dice:.4f}")
if no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Training completato. Best Dice Score: {best_dice:.4f}")
return history
3. Her Şeyi Bölümlere Ayırma Modeli (SAM)
Meta AI yayınlandı SAM (Kirillov ve diğerleri, 2023) iddialı bir hedefle genel bir segmentasyon modeli oluşturun: 1 milyar maskeyle eğitilmiş bir model hangisi segmentlere ayırabilir herhangi bir şey in herhangi bir resim esnek istemleri kullanma (bir noktaya, sınırlayıcı kutuya, metne tıklayın). SAM2 (2024), modeli videolara da genişletti.
3.1 SAM mimarisi
SAM üç ana bileşenden oluşur:
- Görüntü Kodlayıcı: Yoğun görüntü yerleştirmeleri oluşturan Vision Transformer (632M parametreli ViT-H). Görüntü başına yalnızca bir kez çalışır.
- İstemi Kodlayıcı: Farklı türdeki istemleri (noktalar, kutular, maskeler, metin) kod çözücüyle uyumlu yerleştirmelere kodlayın.
- Maske Kod Çözücü: Görüntü yerleştirme + istemlerini birleştiren hafif transformatör, maske oluşturmayı sağlar. Güven puanlarına sahip 3 aday maskesi oluşturun.
# pip install segment-anything
# Download checkpoint: https://github.com/facebookresearch/segment-anything
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
def load_sam_model(
model_type: str = 'vit_h',
checkpoint_path: str = 'sam_vit_h_4b8939.pth',
device: str = 'cuda'
):
"""
Carica il modello SAM.
Tipi disponibili: 'vit_h' (default, max accuratezza), 'vit_l', 'vit_b' (più veloce)
"""
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
return sam
def segment_with_point_prompt(
sam_model,
image: np.ndarray,
point_coords: list[tuple[int, int]],
point_labels: list[int] # 1=foreground, 0=background
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Segmenta con prompt a punti.
Restituisce: (maschere, score, logits) - 3 proposte ordinate per score.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True # genera 3 maschere candidate
)
# Ordina per score decrescente
sorted_idx = np.argsort(scores)[::-1]
return masks[sorted_idx], scores[sorted_idx], logits[sorted_idx]
def segment_with_box_prompt(
sam_model,
image: np.ndarray,
box: tuple[int, int, int, int] # [x1, y1, x2, y2]
) -> tuple[np.ndarray, float]:
"""
Segmenta con prompt bounding box.
Il box definisce la regione di interesse da segmentare.
"""
predictor = SamPredictor(sam_model)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
box=np.array([box]),
multimask_output=False # 1 sola maschera con box prompt
)
return masks[0], float(scores[0])
def automatic_segmentation(sam_model, image: np.ndarray) -> list[dict]:
"""
Segmentazione automatica: SAM segmenta TUTTO nell'immagine
senza nessun prompt. Usa una griglia di punti come seed.
"""
mask_generator = SamAutomaticMaskGenerator(
model=sam_model,
points_per_side=32, # griglia 32x32 = 1024 punti seed
pred_iou_thresh=0.88, # filtra maschere con IoU basso
stability_score_thresh=0.95, # filtra maschere instabili
crop_n_layers=1, # multi-crop per oggetti piccoli
crop_n_points_downscale_factor=2,
min_mask_region_area=100 # rimuovi regioni molto piccole
)
masks = mask_generator.generate(image)
# Ordina per area decrescente
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
print(f"SAM ha trovato {len(masks)} segmenti")
for i, mask in enumerate(masks[:5]):
print(f" Segmento {i+1}: area={mask['area']} "
f"score={mask['predicted_iou']:.3f}")
return masks
def visualize_sam_results(image: np.ndarray, masks: list[dict],
alpha: float = 0.4) -> np.ndarray:
"""Visualizza tutte le maschere SAM con colori random."""
result = image.copy()
np.random.seed(42)
for mask_info in masks:
mask = mask_info['segmentation'] # bool array [H, W]
color = np.random.randint(50, 255, 3)
overlay = result.copy()
overlay[mask] = color
result = cv2.addWeighted(result, 1 - alpha, overlay, alpha, 0)
# Contorno
contours, _ = cv2.findContours(
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(result, contours, -1, color.tolist(), 2)
return result
# Esempio d'uso
sam = load_sam_model('vit_b', 'sam_vit_b_01ec64.pth') # versione più leggera
image = cv2.imread('image.jpg')
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segmenta con un click (punto foreground)
masks, scores, _ = segment_with_point_prompt(
sam, image_rgb,
point_coords=[(320, 240)], # click al centro dell'oggetto
point_labels=[1] # 1 = foreground
)
best_mask = masks[0]
print(f"Maschera trovata con score: {scores[0]:.3f}")
3.2 Video için SAM2
# pip install sam2
# SAM2 rilasciato da Meta AI nell'agosto 2024
import torch
from sam2.build_sam import build_sam2_video_predictor
def segment_video_with_sam2(
video_path: str,
initial_frame: int,
initial_points: list[tuple[int, int]],
checkpoint: str = 'sam2_hiera_large.pt',
config: str = 'sam2_hiera_l.yaml'
) -> dict[int, np.ndarray]:
"""
Segmenta e traccia un oggetto attraverso i frame di un video.
Inizializza con punti sul primo frame, poi traccia automaticamente.
Returns:
Dict frame_idx -> maschera binaria [H, W]
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
predictor = build_sam2_video_predictor(config, checkpoint, device=device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16):
# Inizializza sul video
state = predictor.init_state(video_path=video_path)
predictor.reset_state(state)
# Aggiungi prompt sul frame iniziale
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
inference_state=state,
frame_idx=initial_frame,
obj_id=1,
points=np.array(initial_points),
labels=np.ones(len(initial_points), dtype=np.int32)
)
# Propaga su tutto il video
video_masks = {}
for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
mask = (masks[0][0] > 0.0).cpu().numpy()
video_masks[frame_idx] = mask
print(f"Segmentazione completata: {len(video_masks)} frame processati")
return video_masks
4. Vaka Çalışması: Radyografilerden Akciğer Segmentasyonu
U-Net'i kullanarak göğüs radyografilerinden akciğer segmentasyonuna uyguluyoruz. Montgomery İlçesi Röntgen Veri Kümesi (segmentasyon maskeli 138 radyografi pulmoner radyologlar tarafından manuel olarak açıklanmıştır).
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
class LungXrayDataset(Dataset):
"""Dataset per segmentazione polmoni da radiografie (Montgomery CXR)."""
def __init__(self, image_dir: str, mask_dir: str, img_size: int = 512,
augment: bool = True):
self.image_paths = sorted(Path(image_dir).glob('*.png'))
self.mask_dir = Path(mask_dir)
self.img_size = img_size
if augment:
self.transform = A.Compose([
A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1,
rotate_limit=15, p=0.7),
A.OneOf([
A.GaussNoise(var_limit=(10, 50)),
A.GaussianBlur(blur_limit=3),
A.MedianBlur(blur_limit=3)
], p=0.3),
A.RandomBrightnessContrast(brightness_limit=0.2,
contrast_limit=0.2, p=0.5),
A.CLAHE(clip_limit=2, p=0.3), # Contrast Limited AHE per RX
A.Normalize(mean=[0.485], std=[0.229]), # Grayscale normalization
ToTensorV2()
])
else:
self.transform = A.Compose([
A.Resize(img_size, img_size),
A.Normalize(mean=[0.485], std=[0.229]),
ToTensorV2()
])
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
img_path = self.image_paths[idx]
mask_path = self.mask_dir / img_path.name
# Carica immagine (grayscale)
image = np.array(Image.open(img_path).convert('L'))
mask = np.array(Image.open(mask_path).convert('L'))
# Binarizza maschera
mask = (mask > 127).astype(np.float32)
transformed = self.transform(image=image, mask=mask)
return transformed['image'], transformed['mask'].unsqueeze(0)
def run_lung_segmentation_pipeline():
"""Pipeline completa: dataset -> training -> valutazione -> salvataggio."""
# Data loading
train_dataset = LungXrayDataset(
'data/train/images', 'data/train/masks', augment=True
)
val_dataset = LungXrayDataset(
'data/val/images', 'data/val/masks', augment=False
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,
num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,
num_workers=4, pin_memory=True)
# Modello: U-Net per immagini grayscale
model = UNet(in_channels=1, num_classes=1, features=[32, 64, 128, 256])
# Training
history = train_unet(model, train_loader, val_loader, num_epochs=100)
# Valutazione finale
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('best_unet.pth', map_location=device))
model.eval()
all_dice = []
all_iou = []
with torch.no_grad():
for images, masks in val_loader:
images, masks = images.to(device), masks.to(device)
preds = torch.sigmoid(model(images))
for p, m in zip(preds, masks):
all_dice.append(dice_score(p, m))
all_iou.append(compute_iou(p, m))
print(f"\n=== Risultati Finali ===")
print(f"Dice Score: {np.mean(all_dice):.4f} ± {np.std(all_dice):.4f}")
print(f"IoU: {np.mean(all_iou):.4f} ± {np.std(all_iou):.4f}")
# Risultati attesi per U-Net su Montgomery: Dice ~0.97, IoU ~0.94
5. Her Şeyi Segmentlere Ayırın Model 2: Sıfır Çekimli Video Segmentasyonu
SAM2 (Meta AI, Temmuz 2024) SAM'i video dizilerini kapsayacak şekilde genişletti: ötesinde Statik görüntülerdeki nesneleri etkileşimli komutlarla (nokta, kutu, maske) bölümlere ayırmak için, SAM2, bir modül sayesinde maskeyi video kareleri boyunca otomatik olarak yayar hafıza. Videoda sıfır atış segmentasyonunu güvenilir bir şekilde gerçekleştiren ilk modeldir.
SAM ve SAM2: Temel Farklılıklar
| Özellikler | SAM (2023) | SAM2 (2024) |
|---|---|---|
| Video desteği | Hayır (yalnızca görseller) | Evet (zaman yayılımı) |
| Bellek Modülü | Mevcut olmayan | Çerçeveler arası dikkat içeren hafıza bankası |
| İstem türü | Nokta, Kutu, Maske, Metin (CLIP aracılığıyla) | Nokta, Kutu, Maske (+ video izleme) |
| Hız | ~50ms/görüntü (ViT-H) | ~44ms/kare (Hiera-L), ~8ms (Hiera-T) |
| Eğitim Verileri | SA-1B (1B maskeleri) | SA-V (50,9K video, 642K maske) |
| Çoklu Nesne | Sınırlı | Evet, eşzamanlı çoklu nesne takibi |
import torch
import numpy as np
import cv2
from PIL import Image
# pip install git+https://github.com/facebookresearch/segment-anything-2.git
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
# ============================================================
# PARTE 1: SAM2 su singola immagine
# ============================================================
def sam2_image_segment(image_path: str,
point_coords: list[list[int]],
point_labels: list[int], # 1=foreground, 0=background
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> np.ndarray:
"""
Segmentazione con SAM2 su singola immagine.
point_coords: [[x1, y1], [x2, y2], ...] - punti prompt
point_labels: [1, 1, 0, ...] - 1=foreground, 0=background
Returns: maschera binaria [H, W] bool
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
# Carica immagine
image = np.array(Image.open(image_path).convert('RGB'))
predictor.set_image(image)
# Predici maschera con prompt
masks, scores, logits = predictor.predict(
point_coords=np.array(point_coords),
point_labels=np.array(point_labels),
multimask_output=True, # 3 maschere con confidenze diverse
)
# Prendi la maschera con score più alto
best_idx = np.argmax(scores)
best_mask = masks[best_idx]
print(f"Maschera selezionata: score={scores[best_idx]:.3f}, "
f"area={best_mask.sum()} pixel")
return best_mask # [H, W] bool
def sam2_box_prompt(image_np: np.ndarray,
box: list[int],
predictor: SAM2ImagePredictor) -> np.ndarray:
"""
Segmentazione con prompt box (x1, y1, x2, y2).
Più preciso dei punti per oggetti con bordi definiti.
"""
predictor.set_image(image_np)
masks, scores, _ = predictor.predict(
box=np.array(box),
multimask_output=False, # Box prompt -> singola maschera ottimale
)
return masks[0] # [H, W] bool
# ============================================================
# PARTE 2: SAM2 su video - propagazione temporale
# ============================================================
def sam2_video_segment(video_dir: str,
frame_idx: int,
points: list[list[int]],
labels: list[int],
model_cfg: str = 'sam2_hiera_large.yaml',
checkpoint: str = 'sam2_hiera_large.pt') -> dict:
"""
SAM2 video predictor: segmenta un oggetto nel frame 'frame_idx'
e propaga la maschera automaticamente lungo tutto il video.
video_dir: cartella con frame del video (frame_*.jpg)
Returns: dict {frame_idx: {obj_id: mask}}
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=device)
with torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16):
# Inizializza predictor con la directory video
inference_state = predictor.init_state(video_path=video_dir)
# Aggiungi prompt nel frame di annotazione
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=1, # ID oggetto da trackare
points=np.array(points, dtype=np.float32),
labels=np.array(labels, dtype=np.int32),
)
# Propaga la segmentazione su tutto il video
all_masks = {}
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
for obj_id, mask_logit in zip(out_obj_ids, out_mask_logits):
mask = (mask_logit > 0).squeeze().cpu().numpy()
if out_frame_idx not in all_masks:
all_masks[out_frame_idx] = {}
all_masks[out_frame_idx][int(obj_id)] = mask
return all_masks
# ============================================================
# PARTE 3: SAM2 come labeling tool automatizzato
# ============================================================
class SAM2AutoLabeler:
"""
Usa SAM2 per generare automaticamente maschere di training.
Riduce i costi di annotazione del 60-80% rispetto all'annotazione manuale.
Human-in-the-loop: un umano valida e corregge le predizioni SAM2.
"""
def __init__(self, checkpoint: str = 'sam2_hiera_base_plus.pt',
model_cfg: str = 'sam2_hiera_base_plus.yaml'):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_sam2(model_cfg, checkpoint, device=device)
self.predictor = SAM2ImagePredictor(model)
def auto_label_from_yolo_boxes(self,
image_np: np.ndarray,
yolo_boxes: list[tuple],
min_score: float = 0.7) -> list[dict]:
"""
Genera maschere SAM2 usando bounding box di YOLO come prompt.
Workflow: YOLO rileva oggetti -> SAM2 affina con maschera pixel-perfect.
yolo_boxes: lista di (x1, y1, x2, y2, class_id, confidence)
Returns: lista di {box, class_id, mask, sam_score}
"""
self.predictor.set_image(image_np)
results = []
for x1, y1, x2, y2, class_id, conf in yolo_boxes:
if conf < 0.5:
continue
masks, scores, _ = self.predictor.predict(
box=np.array([x1, y1, x2, y2]),
multimask_output=True,
)
best_idx = np.argmax(scores)
if scores[best_idx] < min_score:
continue
results.append({
'box': (x1, y1, x2, y2),
'class_id': class_id,
'mask': masks[best_idx],
'sam_score': float(scores[best_idx]),
'yolo_conf': float(conf)
})
return results
def save_masks_coco_format(self, results: list[dict],
image_id: int,
output_path: str) -> None:
"""Salva maschere in formato COCO per training Mask R-CNN."""
import json
from pycocotools import mask as coco_mask
annotations = []
for ann_id, r in enumerate(results):
binary_mask = r['mask'].astype(np.uint8)
rle = coco_mask.encode(np.asfortranarray(binary_mask))
rle['counts'] = rle['counts'].decode('utf-8')
area = float(np.sum(binary_mask))
x1, y1, x2, y2 = r['box']
annotations.append({
'id': ann_id,
'image_id': image_id,
'category_id': r['class_id'],
'segmentation': rle,
'area': area,
'bbox': [x1, y1, x2-x1, y2-y1],
'iscrowd': 0
})
with open(output_path, 'w') as f:
json.dump(annotations, f, indent=2)
6. Segmentasyon için En İyi Uygulamalar
Temel Öneriler
- Kayıp seçimi: Dengesiz veri kümeleri için (örneğin geniş bir arka planda küçük lezyonlar) saf BCE yerine Zar Kaybı veya Odak Kaybı kullanın. Birleşik ECB+Zar genellikle en iyi uzlaşmadır.
- Etki alanına özgü normalleştirme: Tıbbi görüntüler (gri tonlamalı) için ImageNet'i değil, belirli veri kümesinde hesaplanan istatistikleri kullanın. Radyografiler için CLAHE ön işlemesi sonuçları önemli ölçüde iyileştirir.
- Muhafazakar veri artırma: Tıpta anatomik olarak anlamlı değilse dikey çevirmeleri uygulamayın. Çok fazla deforme etmeyin: Anatomik yapıların kesin yönelimleri vardır.
- Giriş Çözünürlüğü: U-Net çözünürlüğe duyarlıdır. X-ışınları: minimum 512x512. İnce ayrıntılar için (histoloji, sitoloji): 1024x1024 veya kırpma yaklaşımı.
- İşlem sonrası: Maske kenarlarını keskinleştirmek için Koşullu Rastgele Alanlar (CRF) veya morfolojik işlemler (kapatma, açma) uygulayın.
- Etiketleme için SAM: Eğitim maskelerinin (döngüdeki insan etiketleme) oluşturulmasını hızlandırmak ve açıklama maliyetlerini %60-80 oranında azaltmak için SAM'i kullanın.
Yaygın Hatalar
- Farklı dağılıma sahip veriler üzerinde doğrulama yapmayın: Tıbbi segmentasyon modelleri, etki alanı değişimlerine (farklı tarayıcı, protokol, popülasyon) karşı oldukça kırılgandır. Her zaman farklı merkezlerden gelen verileri doğrulayın.
- Düşük kaliteli maskeleri dikkate almayın: Eğitimde, insan açıklamalarının gözlemciler arası değişkenliği vardır. Mümkünse, çoklu açıklamacı konsensüsü kullanın veya açıklama güvenine dayalı ağırlık kaybı kullanın.
- Zarı yalnızca kayıp olarak kullanın: Zar Kaybı küçük partilerde kararsızdır ve eğimde süreksizlik vardır. Her zaman BCE ile birleştirin veya Genelleştirilmiş Zar Kaybı varyantını kullanın.
- Nadir sınıfların ihmal edilmesi: Çok sınıflı segmentasyonda, nadir sınıflar (birkaç piksel) model tarafından göz ardı edilme eğilimindedir. Nadir sınıflar içeren görüntülerin sınıf ağırlıklı kaybını veya aşırı örneklemesini kullanın.
Sonuçlar
Ana segmentasyon mimarilerini ve bunların pratik uygulamalarını inceledik:
- U-Net: atlama bağlantılarına sahip kodlayıcı-kod çözücü mimarisi, akciğer radyografilerinde Dice ~0,97 ile tıbbi segmentasyon için fiili standart
- Maske R-CNN: her örnek için sınırlayıcı kutu + maske ile örnek segmentasyonu, yoğun doğal sahneler için idealdir
- SAM ve SAM2: Hızlı etiketleme için devrim niteliğinde, etkileşimli istemler (SAM) ve geçici video yayılımı (SAM2) ile evrensel sıfır atış segmentasyonu
- Otomatik etiketleme aracı olarak SAM2: Ek açıklama maliyetlerini %60-80 oranında azaltan YOLO+SAM2 hattı
- Zar Kaybı ve Kombine BCE+Zar: küçük bölgelere sahip dengesiz veri kümeleri için optimum kayıplar
- İşlem sonrası: maske kenarlarını iyileştirmek için matematiksel morfoloji ve CRF







