Pokročilá klasifikace textu: Multi-label, Zero-shot a Few-shot
La Klasifikace textu a jeden z nejběžnějších úkolů NLP, ale v praxi jde daleko za pouhé „pozitivní nebo negativní“. Zpravodajský článek může být stejný politický, ekonomický a mezinárodní čas. Lístek podpory může patřit k několika kategorie současně. Dokument lze klasifikovat, aniž byste jej kdy viděli příklady této kategorie během školení.
V tomto článku se zabýváme klasifikací textu v celé její složitosti: multi-class (více vzájemně se vylučujících tříd), multi-label (více simultánních štítků), hierarchická klasifikace a klasifikace zero-shot (bez příkladů školení pro cílové třídy). Zahrnujeme implementace s italskými datovými sadami, Focal Loss pro nevyvážené datové sady a kompletní výrobní potrubí.
Toto je šestý článek ze série Moderní NLP: od BERT po LLM. Předpokládá znalost BERT a ekosystému HuggingFace.
Co se naučíte
- Rozdíl mezi binární, vícetřídní a víceznačkovou klasifikací — kdy co použít
- Jemné ladění BERT pro klasifikaci více tříd pomocí softmax a kompozitních metrik
- Víceznačková klasifikace s sigmoidem, BCEWithLogitsLoss a Focal Loss
- Vyladění prahových hodnot podle značek v multi-label s optimalizací F1
- Klasifikace Zero-shot s modely NLI (BART, DeBERTa-v3) a vlastními šablonami
- Několikanásobná klasifikace s SetFit: 8-64 příkladů na třídu
- Hierarchická klasifikace s plochým vs. shora dolů
- Správa nevyvážených souborů dat: vážení tříd, ztráta ohniska, převzorkování
- Multi-label metriky: Hammingova ztráta, mikro/makro F1, přesnost podmnožiny
- Klasifikační kanál připravený pro výrobu s ukládáním do mezipaměti a dávkovým odvozováním
1. Taxonomie klasifikace textu
Výběr správného typu klasifikace je kritickým prvním krokem. Volba ovlivňuje ztrátovou funkci, vyhodnocovací metriky a architekturu modelu.
Typy klasifikace textu: Průvodce výběrem
| Typ | Popis | Praktický příklad | Výstupní vrstva | Ztrátová funkce | Hlavní metrika |
|---|---|---|---|---|---|
| Binární | 2 vzájemně se vylučující třídy | Spam versus šunka, pozitivní versus negativní | Sigmoid (1) | BCEWithLogitsLoss | F1, AUC-ROC |
| Vícetřídní | N třídy, výběr 1 | Kategorie zpráv, jazyk textu | Softmax(N) | CrossEntropyLoss | Přesnost, makro F1 |
| Víceznačkové | N tříd, více aktivních současně | Značka článku, více emocí | Sigmoid (N) | BCEWithLogitsLoss | Hammingova ztráta, Micro F1 |
| Hierarchický | Třídy organizované v hierarchii | Kategorie produktu (Elektronika > TV > OLED) | Závisí | Hierarchická ztráta | F1 podle úrovně |
| Nulový výstřel | Třídy nikdy neviděli během tréninku | Směrování na libovolná témata | Skóre související s NLI | Ztráta tréninku NLI | F1, přesnost na třídu |
| Několik snímků (SetFit) | Několik příkladů na třídu (8–64) | Klasifikace podle domény | Vedoucí logistiky | Kontrastní + CE | Přesnost, F1 |
2. Vícetřídní klasifikace s BERT
V případě více tříd je pro každý příklad správná pouze jedna třída. Pojďme použít softmax jako aktivační funkce ve výstupní vrstvě e CrossEntropyLoss jako ztrátová funkce. BERT dosahuje nejmodernějších měřítek, jako jsou AG News (~95 %), yelp-full (~70 %), Odpovědi Yahoo (~77 %).
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TrainingArguments,
Trainer,
EarlyStoppingCallback
)
from datasets import load_dataset
import evaluate
import numpy as np
import torch
# AG News: classifica notizie in 4 categorie (dataset bilanciato)
# World, Sports, Business, Sci/Tech
dataset = load_dataset("ag_news")
print("Dataset AG News:", dataset)
# train: 120,000 esempi (30,000 per classe)
# test: 7,600 esempi
LABELS = ["World", "Sports", "Business", "Sci/Tech"]
num_labels = len(LABELS)
MODEL = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
def tokenize(examples):
return tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128
)
tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
tokenized.set_format("torch")
# Modello multi-class: softmax su 4 classi
model = AutoModelForSequenceClassification.from_pretrained(
MODEL,
num_labels=num_labels,
id2label={i: l for i, l in enumerate(LABELS)},
label2id={l: i for i, l in enumerate(LABELS)}
)
# Metriche composite per classificazione multi-class
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
# Softmax per probabilità
probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).numpy()
max_prob = probs.max(axis=1).mean() # confidenza media
return {
"accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
"f1_macro": f1.compute(predictions=preds, references=labels, average="macro")["f1"],
"f1_weighted": f1.compute(predictions=preds, references=labels, average="weighted")["f1"],
"avg_confidence": float(max_prob)
}
args = TrainingArguments(
output_dir="./results/bert-agnews",
num_train_epochs=3,
per_device_train_batch_size=32,
per_device_eval_batch_size=64,
learning_rate=2e-5,
warmup_ratio=0.1,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1_macro",
fp16=True,
report_to="none",
seed=42
)
trainer = Trainer(
model=model, args=args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["test"],
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
trainer.train()
# Risultati attesi su AG News test:
# Accuracy: ~94-95%, F1 macro: ~94-95%
# Inferenza production-ready
from transformers import pipeline
clf_pipeline = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
device=0 if torch.cuda.is_available() else -1
)
texts = [
"European Central Bank raises interest rates by 0.5 percentage points",
"Juventus wins Champions League final against Real Madrid"
]
predictions = clf_pipeline(texts)
for text, pred in zip(texts, predictions):
print(f" '{text[:50]}...' -> {pred['label']} ({pred['score']:.3f})")
3. Klasifikace pro více značek
Ve více štítcích může mít každý příklad nula, jeden nebo více aktivních štítků. Zásadní změna oproti multi-class je ve výstupní vrstvě (sigmoid místo softmax) a ve ztrátové funkci (BCEWithLogitsLoss místo CrossEntropyLoss). Každá třída je považována za nezávislý binární problém.
3.1 Příprava souboru dat pro více značek
from datasets import Dataset
from transformers import AutoTokenizer
import torch
import numpy as np
# Dataset multi-label: articoli di news con tag multipli
data = {
"text": [
"La BCE alza i tassi di interesse per combattere l'inflazione europea",
"La Juventus batte il Milan 2-1 in una partita emozionante al Bernabeu",
"Apple presenta il nuovo iPhone con chip M4 e AI generativa avanzata",
"Il governo italiano approva la nuova legge fiscale tra polemiche politiche",
"La crisi climatica colpisce le economie dei paesi in via di sviluppo",
"Tesla annuncia nuovo stabilimento in Italia con 2000 posti di lavoro",
"La Commissione Europea propone nuove regole sull'intelligenza artificiale",
],
# Label: [economia, politica, sport, tecnologia, ambiente, italia]
"labels": [
[1, 0, 0, 0, 0, 0], # solo economia
[0, 0, 1, 0, 0, 0], # solo sport
[0, 0, 0, 1, 0, 0], # solo tecnologia
[1, 1, 0, 0, 0, 1], # economia + politica + italia
[1, 0, 0, 0, 1, 0], # economia + ambiente
[1, 0, 0, 1, 0, 1], # economia + tecnologia + italia
[1, 1, 0, 1, 0, 0], # economia + politica + tecnologia
]
}
LABELS = ["economia", "politica", "sport", "tecnologia", "ambiente", "italia"]
NUM_LABELS = len(LABELS)
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
def tokenize_multilabel(examples):
encoding = tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128
)
# Converti labels in float (richiesto da BCEWithLogitsLoss)
encoding["labels"] = [
[float(l) for l in label_list]
for label_list in examples["labels"]
]
return encoding
dataset = Dataset.from_dict(data)
tokenized = dataset.map(tokenize_multilabel, batched=True, remove_columns=["text"])
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "token_type_ids"])
# Analisi della distribuzione delle label
import pandas as pd
labels_df = pd.DataFrame(data["labels"], columns=LABELS)
print("\nDistribuzione label:")
for col in LABELS:
count = labels_df[col].sum()
print(f" {col}: {count}/{len(labels_df)} esempi ({100*count/len(labels_df):.0f}%)")
3.2 Multi-label model s vlastní ztrátou
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch import nn
import torch
import numpy as np
# =========================================
# Trainer con BCEWithLogitsLoss standard
# =========================================
class MultiLabelTrainer(Trainer):
"""Trainer personalizzato per multi-label con BCEWithLogitsLoss."""
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
# BCEWithLogitsLoss per multi-label
# Combina sigmoid + BCE in un'unica operazione numericamente stabile
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits.float(), labels.float().to(logits.device))
return (loss, outputs) if return_outputs else loss
# =========================================
# Focal Loss per dataset sbilanciati
# =========================================
class FocalLossMultiLabelTrainer(Trainer):
"""
Trainer con Focal Loss per gestire dataset multi-label sbilanciati.
Focal Loss riduce il peso degli esempi facili (ben classificati)
e aumenta l'attenzione sugli esempi difficili.
FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
gamma=2 e il valore standard (Lin et al. 2017)
"""
def __init__(self, *args, focal_gamma: float = 2.0, focal_alpha: float = 0.25, **kwargs):
super().__init__(*args, **kwargs)
self.focal_gamma = focal_gamma
self.focal_alpha = focal_alpha
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
# Focal Loss implementation
probs = torch.sigmoid(logits.float())
labels_float = labels.float().to(logits.device)
# Standard BCE term
bce_loss = nn.functional.binary_cross_entropy_with_logits(
logits.float(),
labels_float,
reduction='none'
)
# Focal modulation
pt = probs * labels_float + (1 - probs) * (1 - labels_float)
focal_weight = (1 - pt) ** self.focal_gamma
# Alpha balancing (peso diverso per classe positiva vs negativa)
alpha_t = self.focal_alpha * labels_float + (1 - self.focal_alpha) * (1 - labels_float)
focal_loss = (alpha_t * focal_weight * bce_loss).mean()
return (focal_loss, outputs) if return_outputs else focal_loss
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-multilingual-cased",
num_labels=NUM_LABELS,
problem_type="multi_label_classification"
)
# Metriche multi-label
def compute_multilabel_metrics(eval_pred):
logits, labels = eval_pred
probs = torch.sigmoid(torch.tensor(logits)).numpy()
predictions = (probs >= 0.5).astype(int)
# Hamming Loss: percentuale di label sbagliate (lower is better)
hamming = np.mean(predictions != labels)
# Subset accuracy: % di esempi con TUTTE le label corrette
exact_match = np.mean(np.all(predictions == labels, axis=1))
from sklearn.metrics import f1_score
micro_f1 = f1_score(labels, predictions, average='micro', zero_division=0)
macro_f1 = f1_score(labels, predictions, average='macro', zero_division=0)
return {
"hamming_loss": hamming,
"subset_accuracy": exact_match,
"micro_f1": micro_f1,
"macro_f1": macro_f1
}
args = TrainingArguments(
output_dir="./results/bert-multilabel",
num_train_epochs=5,
per_device_train_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True,
report_to="none"
)
# Usa FocalLoss per dataset sbilanciati, MultiLabelTrainer per bilanciati
trainer = FocalLossMultiLabelTrainer(
model=model, args=args,
train_dataset=tokenized,
compute_metrics=compute_multilabel_metrics,
focal_gamma=2.0,
focal_alpha=0.25
)
trainer.train()
3.3 Optimalizace prahu pro štítek
Výchozí prahová hodnota (0,5) není vždy optimální pro všechny štítky ve více štítcích. Je možné optimalizovat práh pro každý štítek zvlášť, aby se maximalizovala F1. To je důležité zejména u nevyvážených datových sad.
from sklearn.metrics import f1_score
import numpy as np
import torch
def find_optimal_thresholds(logits: np.ndarray, true_labels: np.ndarray,
thresholds=None, label_names=None) -> np.ndarray:
"""
Trova il threshold ottimale per ogni label che massimizza l'F1.
Usa il validation set per la ricerca del threshold.
"""
if thresholds is None:
thresholds = np.arange(0.05, 0.95, 0.05)
probs = 1 / (1 + np.exp(-logits)) # sigmoid
n_labels = logits.shape[1]
optimal_thresholds = np.zeros(n_labels)
print("Ricerca threshold ottimale per label:")
for label_idx in range(n_labels):
best_f1 = 0
best_threshold = 0.5
for threshold in thresholds:
preds = (probs[:, label_idx] >= threshold).astype(int)
f1 = f1_score(true_labels[:, label_idx], preds, zero_division=0)
if f1 > best_f1:
best_f1 = f1
best_threshold = threshold
optimal_thresholds[label_idx] = best_threshold
label_name = label_names[label_idx] if label_names else f"label_{label_idx}"
support = true_labels[:, label_idx].sum()
print(f" {label_name:<15s}: threshold={best_threshold:.2f}, F1={best_f1:.4f} (n={support})")
return optimal_thresholds
# Funzione di predict con thresholds personalizzati
def predict_multilabel(
texts: list,
model,
tokenizer,
thresholds: np.ndarray,
label_names: list,
batch_size: int = 32
) -> list:
"""Predizione multi-label con thresholds per-label ottimizzati."""
model.eval()
all_results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
inputs = tokenizer(batch, return_tensors='pt', truncation=True,
padding=True, max_length=128)
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.sigmoid(logits).numpy()
for sample_probs in probs:
sample_results = []
for j, (prob, threshold, label) in enumerate(
zip(sample_probs, thresholds, label_names)):
if prob >= threshold:
sample_results.append({"label": label, "probability": float(prob)})
all_results.append(sorted(sample_results, key=lambda x: x["probability"], reverse=True))
return all_results
# Esempio di utilizzo
# val_logits, val_labels = get_val_predictions(trainer, val_dataset)
# thresholds = find_optimal_thresholds(val_logits, val_labels, label_names=LABELS)
# predictions = predict_multilabel(texts, model, tokenizer, thresholds, LABELS)
4. Klasifikace nulového výstřelu
La klasifikace zero-shot umožňuje roztřídit texty do kategorií které modelka při tréninku nikdy neviděla. Využij modely trénované na Odvozování přirozeného jazyka (NLI): daný text a hypotéza, model predikuje, zda je hypotéza pravdivá (intailment), nepravdivá (rozpor) nebo nejistý (neutrální).
Proces: text se používá jako „předpoklad“, kategorie jako „hypotéza“ (např. „Tento text je o ekonomii“). Skóre entailment ukazuje jak moc text do této kategorie patří.
from transformers import pipeline
# Modelli NLI consigliati per zero-shot:
# - facebook/bart-large-mnli (inglese, ottimo per EN)
# - cross-encoder/nli-deberta-v3-large (più accurato, EN)
# - MoritzLaurer/mDeBERTa-v3-base-mnli-xnli (multilingua, include IT)
# - joeddav/xlm-roberta-large-xnli (multilingua alternativo)
# Zero-shot per italiano (multilingue)
classifier_it = pipeline(
"zero-shot-classification",
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
device=0 # usa GPU se disponibile
)
# Classificazione di un articolo italiano
text_it = "La BCE ha alzato i tassi di interesse di 25 punti base nella riunione di ottobre."
categories_it = ["economia", "politica", "sport", "tecnologia", "ambiente"]
result = classifier_it(
text_it,
candidate_labels=categories_it,
multi_label=False # True per multi-label simultaneo
)
print("Classificazione testo IT:")
for label, score in zip(result['labels'][:3], result['scores'][:3]):
print(f" {label}: {score:.3f}")
# Zero-shot multi-label
text_multi = "Tesla investe 2 miliardi in pannelli solari riducendo le emissioni CO2."
result_multi = classifier_it(
text_multi,
candidate_labels=categories_it,
multi_label=True
)
print("\nClassificazione multi-label:")
for label, score in zip(result_multi['labels'], result_multi['scores']):
if score > 0.3:
print(f" {label}: {score:.3f}")
# =========================================
# Template personalizzati per migliori risultati
# =========================================
classifier_en = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Default: "This example is {label}."
# Personalizzato: più descrittivo e preciso
text = "The Federal Reserve raised interest rates by 50 basis points."
# Confronto template
templates = {
"default": "This example is {}.",
"topic_specific": "This news article is about {}.",
"domain_specific": "This text is related to {} matters.",
}
for template_name, template in templates.items():
result = classifier_en(
text,
candidate_labels=["economics", "politics", "sports"],
hypothesis_template=template
)
print(f"\nTemplate '{template_name}': top={result['labels'][0]} ({result['scores'][0]:.3f})")
# Template per dominio legale
legal_text = "The court ruled in favor of the plaintiff in the patent infringement case."
legal_result = classifier_en(
legal_text,
candidate_labels=["intellectual property", "criminal law", "employment law"],
hypothesis_template="This legal document concerns {} law."
)
print(f"\nDominio legale: {legal_result['labels'][0]} ({legal_result['scores'][0]:.3f})")
5. Klasifikace několika snímků pomocí SetFit
SetFit (jemné doladění transformátoru vět) a technika, která to umožňuje trénovat přesné klasifikátory s velmi malým počtem příkladů na třídu (8-16 příkladů). Myšlenka je jednoduchá: nejprve natrénujte větný transformátor, aby rozpoznával dvojice podobné/nepodobné pomocí datové sady několika snímků, pak trénujte jednoduchou hlavu logistické klasifikace na výsledných vestavbách.
SetFit překonává standardní jemné ladění a GPT-3 několik snímků v mnoha benchmarcích s 8 příklady na třídu s použitím mnohem menšího modelu.
# pip install setfit
from setfit import SetFitModel, Trainer as SetFitTrainer, TrainingArguments as SetFitArgs
from datasets import Dataset, DatasetDict
import pandas as pd
# Dataset italiano few-shot: solo 8 esempi per classe
train_data = {
"text": [
# Economia (8 esempi)
"I tassi di interesse BCE aumentati dello 0.5%",
"Il PIL italiano cresce dell'1.2% nel terzo trimestre",
"La Borsa di Milano chiude in rialzo del 2.3%",
"L'inflazione scende al 2.8% grazie al calo dei prezzi",
"Fiat annuncia 3000 nuove assunzioni in Piemonte",
"Il deficit italiano supera il 3% del PIL europeo",
"Le esportazioni crescono verso i mercati emergenti",
"L'euro si rafforza rispetto al dollaro americano",
# Sport (8 esempi)
"La Juventus conquista la Coppa Italia ai rigori",
"Jannik Sinner vince il titolo ATP Finals a Torino",
"Ferrari ottiene la pole position a Silverstone",
"La nazionale azzurra batte la Germania 3-1",
"Il Milan acquista un attaccante per 80 milioni",
"La Roma pareggia 2-2 con l'Inter nel posticipo",
"Gianmarco Tamberi difende il titolo europeo nel salto in alto",
"La pallavolo italiana vince il Mondiale femminile",
# Tecnologia (8 esempi)
"OpenAI lancia il nuovo modello GPT con capacità multimodali",
"Apple presenta la nuova serie iPhone con chip 4nm",
"Google acquisisce una startup di intelligenza artificiale",
"Tesla aumenta la produzione di veicoli elettrici del 40%",
"Meta introduce nuovi filtri di privacy per gli utenti",
"Samsung annuncia chip con tecnologia 2nm nel 2025",
"Microsoft integra Copilot AI in Windows 12",
"Il 5G copre ora il 70% della popolazione italiana",
],
"label": [
0, 0, 0, 0, 0, 0, 0, 0, # economia = 0
1, 1, 1, 1, 1, 1, 1, 1, # sport = 1
2, 2, 2, 2, 2, 2, 2, 2 # tecnologia = 2
]
}
# Dataset di test più grande
test_data = {
"text": [
"La BCE mantiene invariati i tassi al 4.5%",
"L'Inter batte il Liverpool in Champions League",
"NVIDIA supera i 2 trilioni di capitalizzazione",
"Il governo italiano approva il Piano Mattei per l'Africa",
],
"label": [0, 1, 2, 0]
}
train_dataset = Dataset.from_dict(train_data)
test_dataset = Dataset.from_dict(test_data)
# Carica modello SetFit multilingue
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
labels=["economia", "sport", "tecnologia"]
)
# Training con SetFit
args = SetFitArgs(
batch_size=16,
num_epochs=1, # epoche per la testa di classificazione
num_iterations=20, # numero di coppie contrastive generate
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
report_to="none",
)
trainer = SetFitTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
metric="accuracy"
)
trainer.train()
# Inferenza
texts = [
"Il governo ha approvato la manovra finanziaria 2025",
"L'Inter ha battuto il Barcellona 3-1 in Champions League",
"OpenAI presenta il nuovo modello reasoning o3-pro"
]
predictions = model.predict(texts)
scores = model.predict_proba(texts)
print("\nPredizioni SetFit:")
for text, pred, prob in zip(texts, predictions, scores):
label_names = ["economia", "sport", "tecnologia"]
print(f" '{text[:45]}...' -> {label_names[pred]} ({max(prob):.3f})")
6. Hierarchická klasifikace
V mnoha scénářích reálného světa jsou kategorie organizovány do hierarchií. Článek lze klasifikovat jako „Technologie > AI > NLP“. Existují dva hlavní přístupy: byt (ignorovat hierarchii) e hierarchický (využít struktury).
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from typing import Dict, List, Tuple
# Esempio di gerarchia di categorie
HIERARCHY = {
"Economia": ["Mercati Finanziari", "Macroeconomia", "Commercio", "Lavoro"],
"Politica": ["Politica Nazionale", "Politica Estera", "Elezioni", "Legislazione"],
"Sport": ["Calcio", "Tennis", "Formula 1", "Atletica"],
"Tecnologia": ["AI", "Smartphone", "Cloud", "Cybersecurity"]
}
class HierarchicalClassifier:
"""
Classificatore gerarchico top-down.
Step 1: classifica nella categoria di primo livello
Step 2: classifica nella sottocategoria (secondo livello)
"""
def __init__(self, coarse_model_path: str, fine_models: Dict[str, str]):
"""
coarse_model: modella categorie di alto livello (Economia, Politica, ...)
fine_models: dict {coarse_category -> fine_model_path}
"""
self.tokenizer = AutoTokenizer.from_pretrained(coarse_model_path)
self.coarse_model = AutoModelForSequenceClassification.from_pretrained(coarse_model_path)
self.coarse_model.eval()
self.fine_models = {}
for cat, path in fine_models.items():
m = AutoModelForSequenceClassification.from_pretrained(path)
m.eval()
self.fine_models[cat] = m
def predict(self, text: str, return_scores: bool = False) -> Tuple[str, str, float]:
"""
Classifica in modo gerarchico.
Returns: (coarse_label, fine_label, confidence)
"""
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
# Step 1: classificazione coarse
with torch.no_grad():
coarse_logits = self.coarse_model(**inputs).logits
coarse_probs = torch.softmax(coarse_logits, dim=-1)[0]
coarse_id = coarse_probs.argmax().item()
coarse_label = self.coarse_model.config.id2label[coarse_id]
coarse_score = coarse_probs[coarse_id].item()
# Step 2: classificazione fine (se disponibile per questa categoria)
fine_label = None
fine_score = None
if coarse_label in self.fine_models:
with torch.no_grad():
fine_logits = self.fine_models[coarse_label](**inputs).logits
fine_probs = torch.softmax(fine_logits, dim=-1)[0]
fine_id = fine_probs.argmax().item()
fine_label = self.fine_models[coarse_label].config.id2label[fine_id]
fine_score = fine_probs[fine_id].item()
return {
"coarse": coarse_label,
"fine": fine_label,
"coarse_confidence": coarse_score,
"fine_confidence": fine_score,
"full_path": f"{coarse_label} > {fine_label}" if fine_label else coarse_label
}
print("HierarchicalClassifier definito!")
7. Komplexní metriky pro multi-label
from sklearn.metrics import (
f1_score, precision_score, recall_score,
hamming_loss, accuracy_score, average_precision_score
)
import numpy as np
def multilabel_evaluation_report(y_true: np.ndarray, y_pred: np.ndarray,
y_proba: np.ndarray, label_names: list) -> dict:
"""Report completo per classificazione multi-label."""
print("=" * 65)
print("MULTI-LABEL CLASSIFICATION REPORT")
print("=" * 65)
# Metriche globali
hl = hamming_loss(y_true, y_pred)
sa = accuracy_score(y_true, y_pred) # subset accuracy
micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
print(f"\n{'Hamming Loss':<25s}: {hl:.4f} (lower is better)")
print(f"{'Subset Accuracy':<25s}: {sa:.4f} (all labels must match)")
print(f"{'Micro F1':<25s}: {micro_f1:.4f} (label-weighted)")
print(f"{'Macro F1':<25s}: {macro_f1:.4f} (unweighted)")
print(f"{'Weighted F1':<25s}: {f1_score(y_true, y_pred, average='weighted', zero_division=0):.4f}")
# AUC per label (richiede probabilità, non predizioni binarie)
if y_proba is not None:
try:
macro_auc = average_precision_score(y_true, y_proba, average='macro')
print(f"{'Macro AP (AUC)':<25s}: {macro_auc:.4f}")
except Exception:
pass
# Per ogni label
print("\nPer-label metrics:")
header = f"{'Label':<18s} {'Precision':>10s} {'Recall':>10s} {'F1':>8s} {'Support':>10s}"
print(header)
print("-" * 65)
for i, label in enumerate(label_names):
prec = precision_score(y_true[:, i], y_pred[:, i], zero_division=0)
rec = recall_score(y_true[:, i], y_pred[:, i], zero_division=0)
f1 = f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
support = int(y_true[:, i].sum())
print(f"{label:<18s} {prec:>10.4f} {rec:>10.4f} {f1:>8.4f} {support:>10d}")
return {
"hamming_loss": hl, "subset_accuracy": sa,
"micro_f1": micro_f1, "macro_f1": macro_f1
}
8. Klasifikační potrubí připravené na výrobu
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np
from functools import lru_cache
from typing import Union, List, Dict
import time
class ProductionClassifier:
"""
Classificatore production-ready con:
- Caching degli input tokenizzati
- Batch inference per efficienza
- Supporto multi-class e multi-label
- Monitoraggio latenza e confidenza
"""
def __init__(
self,
model_path: str,
task: str = "multi_class", # "multi_class" o "multi_label"
thresholds: np.ndarray = None,
max_length: int = 128,
batch_size: int = 32
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
self.task = task
self.label_names = list(self.model.config.id2label.values())
self.thresholds = thresholds if thresholds is not None else np.full(len(self.label_names), 0.5)
self.max_length = max_length
self.batch_size = batch_size
self._latencies = []
@torch.no_grad()
def predict(self, texts: Union[str, List[str]]) -> List[Dict]:
"""Predizione con monitoraggio latenza."""
if isinstance(texts, str):
texts = [texts]
start = time.perf_counter()
all_results = []
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i+self.batch_size]
inputs = self.tokenizer(
batch,
return_tensors='pt',
truncation=True,
padding=True,
max_length=self.max_length
).to(self.device)
outputs = self.model(**inputs)
logits = outputs.logits.cpu().numpy()
if self.task == "multi_class":
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
for p in probs:
pred_id = p.argmax()
all_results.append({
"label": self.label_names[pred_id],
"score": float(p[pred_id]),
"all_scores": {name: float(score) for name, score in zip(self.label_names, p)}
})
else: # multi_label
probs = 1 / (1 + np.exp(-logits))
for p in probs:
labels = [
{"label": name, "score": float(score)}
for name, score, thr in zip(self.label_names, p, self.thresholds)
if score >= thr
]
all_results.append({"labels": sorted(labels, key=lambda x: -x["score"])})
latency_ms = (time.perf_counter() - start) * 1000
self._latencies.append(latency_ms)
return all_results
def get_stats(self) -> Dict:
"""Statistiche sulle latenze di inference."""
if not self._latencies:
return {}
return {
"avg_latency_ms": np.mean(self._latencies),
"p99_latency_ms": np.percentile(self._latencies, 99),
"total_predictions": len(self._latencies)
}
# Esempio di utilizzo
# clf = ProductionClassifier("./models/my-classifier", task="multi_label")
# results = clf.predict(["Il governo alza le tasse", "La Juve vince lo scudetto"])
# stats = clf.get_stats()
print("ProductionClassifier definito!")
9. Klasifikace pomocí generativních modelů (LLM Prompting)
S příchodem LLM je nyní možné provádět klasifikaci textu jednoduše prostřednictvím výzvy, bez jakéhokoli školení. Tento přístup a zvláště užitečné pro rychlé prototypování a nové nebo vzácné kategorie.
from transformers import pipeline
import json
# =========================================
# Classificazione con LLM tramite prompting
# =========================================
# Approccio 1: Con un modello instruction-following (es. Mistral-7B-Instruct)
def classify_with_llm(text: str, categories: list, model_pipeline) -> dict:
"""
Classificazione zero-shot con LLM instruction-following.
Il modello non richiede fine-tuning: usa la comprensione del linguaggio naturale.
"""
categories_str = ", ".join(categories)
prompt = f"""Classifica il seguente testo in UNA delle categorie: {categories_str}.
Testo: "{text}"
Rispondi SOLO con il nome della categoria, senza spiegazioni.
Categoria:"""
response = model_pipeline(
prompt,
max_new_tokens=20,
temperature=0.0, # deterministic
do_sample=False
)[0]['generated_text']
# Estrai la categoria dalla risposta
answer = response[len(prompt):].strip().split('\n')[0].strip()
# Valida che la risposta sia una categoria valida
for cat in categories:
if cat.lower() in answer.lower():
return {"label": cat, "method": "llm", "raw_answer": answer}
return {"label": "sconosciuto", "method": "llm", "raw_answer": answer}
# Approccio 2: Prompt con esempi (few-shot)
def classify_with_fewshot(text: str, categories: list, examples: list, model_pipeline) -> dict:
"""
Classificazione few-shot: fornisce esempi nel prompt per guidare il modello.
"""
examples_str = ""
for ex in examples[:3]: # massimo 3 esempi per non superare il context window
examples_str += f'Testo: "{ex["text"]}"\nCategoria: {ex["label"]}\n\n'
prompt = f"""Classifica testi nelle categorie: {", ".join(categories)}.
Esempi:
{examples_str}Testo: "{text}"
Categoria:"""
response = model_pipeline(
prompt, max_new_tokens=15, temperature=0.0, do_sample=False
)[0]['generated_text']
answer = response[len(prompt):].strip().split('\n')[0].strip()
return {"label": answer, "method": "few-shot-llm"}
# =========================================
# Confronto: zero-shot NLI vs LLM prompting vs BERT fine-tuned
# =========================================
comparison_table = [
{"method": "BERT fine-tuned", "F1": "0.95+", "speed": "veloce", "dati": "1000+ esempi", "costo": "basso"},
{"method": "SetFit (few-shot)", "F1": "0.85+", "speed": "veloce", "dati": "8-64 esempi", "costo": "basso"},
{"method": "Zero-shot NLI", "F1": "0.70+", "speed": "medio", "dati": "zero esempi", "costo": "basso"},
{"method": "LLM prompting (7B)", "F1": "0.75+", "speed": "lento", "dati": "zero esempi", "costo": "medio"},
{"method": "LLM few-shot (7B)", "F1": "0.82+", "speed": "lento", "dati": "3-10 esempi", "costo": "medio"},
{"method": "GPT-4 prompting (API)", "F1": "0.88+", "speed": "molto lento", "dati": "zero esempi", "costo": "alto"},
]
print("=== Confronto Metodi di Classificazione ===")
print(f"{'Metodo':<30s} {'F1':<10s} {'Velocita':<15s} {'Dati Richiesti':<18s} {'Costo'}")
print("-" * 85)
for row in comparison_table:
print(f"{row['method']:<30s} {row['F1']:<10s} {row['speed']:<15s} {row['dati']:<18s} {row['costo']}")
print("\nRaccomandazione: inizia con zero-shot NLI per validare il task,")
print("poi fine-tuna BERT se hai dati, oppure usa SetFit con pochi esempi annotati.")
Anti-Pattern: Použití přesnosti jako jediné metriky
S nevyváženými soubory dat (např. 95 % negativních, 5 % pozitivních) je model, který předpovídá vždy "negativní" získá 95% přesnost, ale je to k ničemu. Vždy používejte F1, přesnost a vyvolání pro binární a vícetřídní klasifikaci. V multi-labelu použijte Hamming loss a micro/macro F1. Nikdy neignorujte distribuci tříd v datové sadě.
Průvodce výběrem přístupu
| Scénář | Doporučený přístup | Čas nastavení |
|---|---|---|
| Pevné kategorie, spousta dat (>5 000) | Standardní jemné doladění BERT | hodiny |
| Pevné kategorie, málo dat (<100) | SetFit (pár snímků) | Zápis |
| Proměnné nebo nové kategorie | Zero-shot NLI + vlastní šablony | Instanty |
| Víceznačková, vyvážená datová sada | BERT + BCEWithLogitsLoss + ladění prahu | hodiny |
| Víceznačková, nevyvážená datová sada | Ohnisková ztráta + ladění prahu podle štítku | hodiny |
| Hierarchie kategorií | HierarchicalClassifier shora dolů | Dny |
| Rychlé prototypování | Potrubí s nulovým záběrem | Sekundy |
10. Model Benchmarks a průvodce výběrem
Výběr správného modelu pro klasifikaci textu závisí na typu úlohy, objem dat, požadavky na latenci a dostupný hardware. Zde je praktický benchmark hlavních přístupů ke standardním datovým sadám.
Srovnávací klasifikace textu (2024–2025)
| Úkoly | Datové sady | Model | Přesnost / F1 | Poznámky |
|---|---|---|---|---|
| Binární sentiment | SST-2 (EN) | DistilBERT doladěn | Acc 92,7 % | 6x rychlejší než BERT-base |
| Binární sentiment | SST-2 (EN) | RoBERTa-velký doladěný | Acc 96,4 % | Nejmodernější EN |
| Vícetřídní (6 tříd) | AG novinky | BERT-base jemně vyladěna | Acc 94,8 % | Standardní benchmark zpráv |
| Víceznačkové (90 kat.) | Reuters-21578 | RoBERTa + BCEloss | Micro-F1 89,2 % | 90 kategorií agentury Reuters |
| Nulový výstřel | Odpovědi Yahoo | BART-velký-MNLI | Acc 70,3 % | Žádné tréninkové údaje |
| Několik výstřelů (8 příkladů) | SST-2 (EN) | SetFit (MiniLM) | Acc 88,1 % | Zaznamenáno pouze 8 příkladů |
| Italský sentiment | SENTIPOLC 2016 | dbmdz BERT doladěn | F1 91,3 % | Nejlepší italský model |
Závěry a další kroky
Moderní klasifikace textu daleko přesahuje binární klasifikaci. Zero-shot, few-shot a multi-label jsou skutečné scénáře, které vyžadují specifické přístupy. S nástroji v tomto článku — od SetFit pro několik snímků po Focal Loss for nevyvážené datové sady, od zero-shot NLI po hierarchické klasifikátory – máte základy řešit jakýkoli scénář klasifikace textu ve výrobě.
Moderní série NLP pokračuje
- Předchozí: Rozpoznávání pojmenované entity — extrakce entity pomocí BERT
- Další: HuggingFace Transformers: Kompletní průvodce — trenér ekosystémů a API
- Článek 8: Lokální doladění LLM — LoRA a QLoRA na spotřebitelských GPU
- Článek 9: Sémantická podobnost — vložení vět a FAISS pro vyhledávání
- Článek 10: Monitorování NLP — detekce posunu a automatické přeškolení
- Související série: AI inženýrství/RAG — klasifikace zero-shot jako směrování v RAG
- Související série: Pokročilé hluboké učení — pokročilé klasifikační architektury







