Detekce posunu modelu a automatické přeškolení ve výrobě
Konečně jste nasadili svůj model do výroby. Metriky jsou výborné, tým je spokojený a zúčastněné strany tleskají. O několik týdnů později si někdo všimne, že předpovědi se zdají být méně přesné. O měsíc později je model jasně degradován. Vítejte u nejzáludnějšího problému ve strojovém učení ve výrobě: modelový drift.
Podle výzkumu společnosti Gartner, mimo 65 % ML modelů ve výrobě výrazně degraduje významné do 12 měsíců od nasazení, často aniž by si to týmy včas uvědomily. Data jsou ještě znepokojivější v maloobchodě a financích, kde se distribuce dat mění rychle reagovat na trendy na trhu, sezónnost a chování uživatelů.
V této příručce vytvoříme kompletní systém detekce driftu a automatické přeškolení: porozumíme různým typům driftu, implementujeme detektory s Evidently AI, NannyML a Alibi Detect, nakonfigurujeme statistické testy (KS, PSI, Chi-Square), integrujeme Prometheus a Grafana pro sledování kontinuální a vytvoříme automatické rekvalifikační kanály spouštěné výstrahami.
Co se naučíte
- Rozdíl mezi posunem dat, posunem konceptu, posunem vlastností a posunem označení
- Statistické testy k detekci driftu: KS test, PSI, Chi-Square, MMD
- Praktická implementace s Evidently AI, NannyML a Alibi Detect
- Monitorovací panel s Prometheus a Grafana
- Automatické upozornění a přeškolení pomocí MLflow
- Osvědčené postupy pro nízkorozpočtové MLO produkční úrovně
Proč je drift kritickým problémem
Skutečný svět není statický. Odrážela se data, která váš model viděl během tréninku konkrétní statistické rozložení, „snímek“ světa v daném okamžiku. Ale svět se stále mění: uživatelské návyky se vyvíjejí, trhy kolísají, systémy na vyšší úrovni změnit formát dat, dojde k neočekávaným událostem, jako jsou pandemie nebo hospodářské krize.
Základním problémem driftu je tichá degradace: model se zastaví být přesné, ale přesto vytvářet předpovědi bez technických chyb. Služba odpovídá s HTTP 200 protokoly neukazují žádné výjimky, ale rozhodnutí založená na těchto předpovědích ano čím dál tím víc se mýlit. Bez aktivního monitorovacího systému může tato degradace měsíce bez povšimnutí.
Ekonomický dopad nedetekovaného driftu
Degradovaný model detekce podvodů může nechat podvodné transakce neodhalit. Cenový systém, který se mění, může stát miliony v nekonkurenčních cenách. Modelka zhoršená predikce odchodu vede k retenčním kampaním promarněným na nesprávné zákazníky. Náklady na monitorování jsou vždy nižší než náklady na nezjištěný posun.
Taxonomie driftu: čtyři základní typy
Před implementací řešení je nezbytné porozumět Co je to driftování. Existují čtyři hlavní kategorie driftu, z nichž každá má různé příčiny a různé strategie detekce.
1. Posun dat (posun kovariát)
Il datový drift, také známý jako kovariátní posun, se stane, když rozložení vstupních vlastností P(X) se oproti tréninku mění, ale vztah mezi prvek a označení P(Y|X) zůstávají stabilní. Klasický příklad: model byl natrénován uživatelé určité věkové skupiny, ale produkt je přijat novou demografickou skupinou.
Posun dat je nejběžnějším typem a je nejsnáze zjistitelný, protože vyžaduje pouze monitorování distribuce vstupních funkcí, bez potřeby štítků. Lze jej detekovat také v v reálném čase, než výsledky ovlivní předpovědi.
2. Koncept Drift
Il koncept drift a záludnější: vztah P(Y|X) mezi funkcemi a štítky se změní, i když distribuce prvku X zůstane stabilní. Příklad: model Analýza sentimentu natrénovaná na tweetech z roku 2022 nerozumí žargonu roku 2025. Sémantika slov (X) se změnila, takže zobrazení X → Y je jiné.
Posun konceptu vyžaduje, aby byla základní pravda přímo detekována: musí být porovnána předpovědi se skutečnými štítky. Když přijdou pozdě (jako ve scénářích predikce odchodu s 90denními okny pozorování), používají se proxy metriky jako je předpovědní drift nebo rozdělení skóre pravděpodobnosti.
3. Drift funkcí
Il rys drift a podmnožinu datového posunu, který se týká specifikací kritické vlastnosti pro model. Ne všechny funkce mají stejný dopad: vlastnost s vysokou důležitostí, která se posunuje a je mnohem kritičtější než funkce s nízkou relevancí. Nástroje funkce důležitosti (SHAP, důležitost permutace) pomáhají stanovit priority monitorování.
4. Posun štítku (předchozí posun pravděpodobnosti)
Il drift štítku nastane, když distribuce cílových štítků P(Y) změnit. V binárním klasifikačním modelu (spam/nespam), pokud náhle dojde k 90 % zpráv je spam místo obvyklých 10 %, model je kalibrován pro jednu distribuci jiné a předpovědi budou zkreslené. Tento typ driftu je běžný ve scénářích s třídní nerovnováha proměnná v čase.
Souhrn typů driftů
- Posun data: P(X) se mění, P(Y|X) stabilní. Objevitelné bez štítku.
- Koncept Drift: P(Y|X) se změní. Vyžaduje štítek nebo proxy metriky.
- Vlastnosti driftu: Specifické vlastnosti se mění. Priorita založená na důležitosti.
- Posun štítku: P(Y) se změní. Sledovat distribuci předpovědí.
Statistické testy pro detekci driftu
Statistická detekce driftu je založena na srovnání dvou distribucí: distribuce referenční (školení nebo období stabilní produkce) a aktuální rozložení (monitorovací okno). Různé statistické testy mají různé charakteristiky citlivosti, interpretovatelnosti a výpočetních nákladů.
Kolmogorov-Smirnov test (KS)
Il KS test a nejpoužívanější pro spojité funkce. Změřte maximální vzdálenost mezi kumulativními distribučními funkcemi (CDF) dvou distribucí. Získaná p-hodnota označuje pravděpodobnost, že dva vzorky pocházejí ze stejné distribuce: nízká p-hodnota (typicky < 0,05) signalizuje statisticky významný posun.
Výhody: nepředpokládá specifické rozdělení (neparametrické), robustní, snadno proveditelné interpretovat vizuálně. Omezení: citlivý na distribuční ocasy, méně výkonný s malými vzorky může poskytnout falešně pozitivní výsledky s velkými soubory dat.
Index stability populace (PSI)
Il PSI a narozený v bankovním sektoru sledovat stabilitu rozdělení rizikových skóre. Rozdělí obě rozdělení do segmentů a vypočítá součet vážených rozdílů mezi proporcemi. Standardní výklad je:
- PSI < 0,1: žádná významná změna
- PSI 0,1 - 0,2: mírná změna, monitor
- PSI > 0,2: Významná změna, nutná akce
PSI je pro obchodní partnery velmi intuitivní a vztahuje se na obě spojité funkce (s diskretizací na decily) a kategorické. A zvláště oblíbené v modelech úvěrového hodnocení a odhalování podvodů.
Chí-kvadrát test
Il Chí-kvadrát test a základní test pro kategorické rysy. Porovnejte pozorované frekvence s očekávanými a vytváří p-hodnotu. A vhodné, když funkce mají omezený počet kategorií a vzorky jsou dostatečně velké (frekvence čekat > 5 pro každou kategorii). U prvků s vysokou mohutností se doporučuje seskupení vzácné kategorie.
Maximální střední nesrovnalost (MMD)
L'MMD a test založený na jádře, který měří vzdálenost mezi dvěma distribucemi v Hilbertově prostoru. Je zvláště výkonný pro detekci rozdílů ve strukturách multivariační a používá jej Alibi Detect pro posun tabulkových dat, obrázků a textu. Výhodou je, že nevyžaduje volbu lopatek ani diskretizační parametry.
Implementace s Evidently AI
Evidentně AI se stala standardní open-source knihovnou pro monitorování modelů ML v Pythonu s více než 20 miliony stažení. Nabízí předdefinované předvolby pro nejběžnější případy použití a integruje se s jakýmkoli orchestrátorem pracovních postupů.
# Installazione
pip install evidently
import pandas as pd
import numpy as np
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, DataQualityPreset, ClassificationPreset
from evidently.metrics import (
DatasetDriftMetric,
DataDriftTable,
ColumnDriftMetric,
ColumnSummaryMetric
)
# --- Setup dati di riferimento e produzione ---
# Carica training data (reference)
reference_data = pd.read_parquet("data/training_features.parquet")
# Carica batch produzione ultimo mese
current_data = pd.read_parquet("data/production_batch_2025_02.parquet")
# Feature columns
feature_columns = [
"age", "tenure_months", "monthly_charges",
"total_charges", "num_support_tickets",
"contract_type", "payment_method"
]
# --- Report Data Drift ---
drift_report = Report(metrics=[
DatasetDriftMetric(), # overall drift summary
DataDriftTable(), # per-feature drift table
ColumnDriftMetric(column_name="monthly_charges"),
ColumnDriftMetric(column_name="contract_type"),
ColumnSummaryMetric(column_name="monthly_charges"),
])
drift_report.run(
reference_data=reference_data[feature_columns],
current_data=current_data[feature_columns]
)
# Salva report HTML interattivo
drift_report.save_html("reports/drift_report_2025_02.html")
# Estrai metriche programmaticamente
report_dict = drift_report.as_dict()
dataset_drift = report_dict["metrics"][0]["result"]
print(f"Dataset drift detected: {dataset_drift['dataset_drift']}")
print(f"Features drifted: {dataset_drift['number_of_drifted_columns']}/{dataset_drift['number_of_columns']}")
print(f"Share of drifted features: {dataset_drift['share_of_drifted_columns']:.1%}")
Evidentně generuje interaktivní HTML sestavy s vizualizacemi distribucí, histogramy překryvné a souhrnné tabulky. Statistický test je uveden pro každý prvek použitá (automaticky vybraná na základě typu dat), p-hodnota nebo testovací statistika, a vlajka drift/no-drift.
Testovací sada s vlastními prahovými hodnotami
Per integrare Evidently in una pipeline CI/CD o in un workflow Airflow/Prefect, la Test Suite di Evidently e lo strumento giusto: permette di definire soglie precise e restituisce pass/fail in modo programmatico.
from evidently.test_suite import TestSuite
from evidently.tests import (
TestNumberOfDriftedColumns,
TestShareOfDriftedColumns,
TestColumnDrift,
TestDatasetDrift
)
# --- Test Suite con soglie personalizzate ---
drift_test_suite = TestSuite(tests=[
# Non più del 20% delle feature deve driftare
TestShareOfDriftedColumns(lt=0.2),
# Feature critiche: test individuali con soglie aggressive
TestColumnDrift(
column_name="monthly_charges",
stattest="ks",
stattest_threshold=0.05
),
TestColumnDrift(
column_name="contract_type",
stattest="chi2",
stattest_threshold=0.05
),
TestColumnDrift(
column_name="num_support_tickets",
stattest="psi",
stattest_threshold=0.1 # PSI < 0.1 = no drift
),
# Dataset-level drift test
TestDatasetDrift(stattest_threshold=0.05),
])
drift_test_suite.run(
reference_data=reference_data[feature_columns],
current_data=current_data[feature_columns]
)
# Risultato pass/fail per la pipeline
test_result = drift_test_suite.as_dict()
all_passed = all(
test["status"] == "SUCCESS"
for test in test_result["tests"]
)
if not all_passed:
print("DRIFT DETECTED - Pipeline triggering retraining...")
for test in test_result["tests"]:
if test["status"] != "SUCCESS":
print(f" FAILED: {test['name']} - {test['description']}")
# Trigger retraining (vedi sezione retraining)
trigger_retraining_pipeline()
else:
print("All drift tests passed - Model healthy")
Monitoring con NannyML: Performance Senza Label
NannyML risolve uno dei problemi più difficili del model monitoring: stimare le performance del modello quando le etichette reali non sono ancora disponibili. In un modello di churn prediction, le etichette (se il cliente ha effettivamente churned) potrebbero arrivare solo 90 giorni dopo la predizione. NannyML usa il metodo Confidence-Based Performance Estimation (CBPE) per stimare accuracy, F1 e AUC in real-time usando solo le distribuzioni degli score.
pip install nannyml
import nannyml as nml
import pandas as pd
# Carica i dati
reference_df = pd.read_parquet("data/reference_with_targets.parquet")
analysis_df = pd.read_parquet("data/production_last_30_days.parquet")
# --- CBPE: Stima delle performance senza label ---
estimator = nml.CBPE(
y_pred_proba="churn_probability",
y_pred="churn_predicted",
y_true="churned", # presente solo nel reference
timestamp_column_name="prediction_date",
problem_type="binary_classification",
metrics=["roc_auc", "f1", "precision", "recall"],
chunk_size=500 # 500 predizioni per chunk temporale
)
estimator.fit(reference_df)
results = estimator.estimate(analysis_df)
# Visualizza risultati con alert automatici
figure = results.plot()
figure.show()
# Estrai metriche per alerting
estimated_metrics = results.to_df()
latest_chunk = estimated_metrics.tail(1)
auc_lower = latest_chunk["estimated_roc_auc_lower_confidence_boundary"].values[0]
if auc_lower < 0.70:
print(f"ALERT: AUC stimato < 0.70 (lower bound: {auc_lower:.3f})")
trigger_retraining_pipeline()
# --- Univariate Drift Detection ---
univariate_calc = nml.UnivariateDriftCalculator(
column_names=["monthly_charges", "tenure_months", "num_tickets"],
timestamp_column_name="prediction_date",
continuous_methods=["kolmogorov_smirnov", "jensen_shannon"],
categorical_methods=["chi2", "jensen_shannon"],
chunk_size=500
)
univariate_calc.fit(reference_df)
drift_results = univariate_calc.calculate(analysis_df)
# Plotta il drift nel tempo per ogni feature
drift_figure = drift_results.filter(period="analysis").plot()
drift_figure.show()
NannyML produce grafici temporali che mostrano l'evoluzione del drift nel tempo, con bande di confidenza e alert visivi. Questo e particolarmente utile per capire quando il drift e iniziato e se sta peggiorando o stabilizzandosi.
Alibi Detect: Drift Detection Avanzata con MMD e LSDD
Alibi Detect (by Seldon) e la libreria di riferimento per detection avanzata che va oltre le statistiche univariate. Supporta MMD (Maximum Mean Discrepancy) per dati tabulari e immagini, LSDD (Least-Squares Density Difference) e rilevazione di outlier. E ideale quando si ha bisogno di rilevare drift multivariato complesso.
pip install alibi-detect
import numpy as np
from alibi_detect.cd import MMDDrift, KSDrift, TabularDrift
from alibi_detect.saving import save_detector, load_detector
# Carica dati di riferimento (numpy array)
X_ref = reference_data[feature_columns].values.astype(np.float32)
X_current = current_data[feature_columns].values.astype(np.float32)
# --- KS Drift per feature continue ---
ks_detector = KSDrift(
x_ref=X_ref,
p_val=0.05, # soglia p-value
alternative="two-sided"
)
ks_preds = ks_detector.predict(
X_current,
drift_type="batch",
return_p_val=True,
return_distance=True
)
print("KS Drift Results:")
print(f" Drift detected: {ks_preds['data']['is_drift']}")
print(f" p-values per feature: {ks_preds['data']['p_val']}")
print(f" Features drifted: {ks_preds['data']['is_drift'].sum()}")
# --- MMD Drift per rilevazione multivariata ---
# Più potente per distribuzioni complesse
mmd_detector = MMDDrift(
x_ref=X_ref,
backend="pytorch", # o "tensorflow"
p_val=0.05,
n_permutations=200 # più alto = più preciso ma più lento
)
mmd_preds = mmd_detector.predict(
X_current,
return_p_val=True,
return_distance=True
)
print(f"\nMMD Drift (multivariato):")
print(f" Drift detected: {mmd_preds['data']['is_drift']}")
print(f" p-value: {mmd_preds['data']['p_val']:.4f}")
print(f" MMD^2 statistic: {mmd_preds['data']['distance']:.6f}")
# --- TabularDrift: test ottimizzato per dati tabulari misti ---
tabular_detector = TabularDrift(
x_ref=X_ref,
p_val=0.05,
categories_per_feature={
4: None, # feature index 4 = contract_type (categorica)
6: None # feature index 6 = payment_method (categorica)
},
)
# Salva detector per riutilizzo
save_detector(tabular_detector, "models/drift_detector/")
# Successivamente carica e usa
# loaded_detector = load_detector("models/drift_detector/")
Architettura del Sistema di Monitoring
Un sistema di monitoring production-grade richiede più componenti integrati: un layer di raccolta metriche, uno storage time-series, un sistema di visualizzazione e un motore di alerting. La combinazione Prometheus + Grafana e lo standard open-source per questo use case, con ampia integrazione nell'ecosistema Kubernetes.
# monitoring_service.py
# Servizio FastAPI che espone metriche di drift per Prometheus
from fastapi import FastAPI, BackgroundTasks
from prometheus_client import Counter, Gauge, Histogram, generate_latest, CONTENT_TYPE_LATEST
from starlette.responses import Response
import pandas as pd
import schedule
import threading
import time
from datetime import datetime, timedelta
import logging
logger = logging.getLogger(__name__)
app = FastAPI(title="ML Monitoring Service")
# --- Prometheus Metrics ---
DRIFT_GAUGE = Gauge(
"ml_feature_drift_psi",
"Population Stability Index per feature",
labelnames=["feature_name", "model_name", "model_version"]
)
DATASET_DRIFT_GAUGE = Gauge(
"ml_dataset_drift_detected",
"1 se drift rilevato a livello dataset, 0 altrimenti",
labelnames=["model_name", "model_version"]
)
DRIFT_FEATURES_COUNT = Gauge(
"ml_drifted_features_count",
"Numero di feature che mostrano drift",
labelnames=["model_name"]
)
ESTIMATED_AUC = Gauge(
"ml_estimated_auc",
"AUC stimato via CBPE (NannyML)",
labelnames=["model_name", "model_version"]
)
PREDICTION_COUNT = Counter(
"ml_predictions_total",
"Numero totale di predizioni",
labelnames=["model_name", "outcome"]
)
INFERENCE_LATENCY = Histogram(
"ml_inference_duration_seconds",
"Latenza inference in secondi",
labelnames=["model_name"],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5]
)
# --- Funzione di calcolo drift ---
def calculate_and_update_drift_metrics(
model_name: str,
model_version: str,
reference_data: pd.DataFrame,
current_data: pd.DataFrame,
feature_columns: list
):
"""Calcola PSI per ogni feature e aggiorna gauge Prometheus."""
from evidently.report import Report
from evidently.metrics import DatasetDriftMetric, DataDriftTable
report = Report(metrics=[
DatasetDriftMetric(stattest="psi"),
DataDriftTable(stattest="psi"),
])
report.run(
reference_data=reference_data[feature_columns],
current_data=current_data[feature_columns]
)
result = report.as_dict()
# Dataset-level drift
dataset_result = result["metrics"][0]["result"]
drift_detected = 1 if dataset_result["dataset_drift"] else 0
DATASET_DRIFT_GAUGE.labels(
model_name=model_name,
model_version=model_version
).set(drift_detected)
DRIFT_FEATURES_COUNT.labels(
model_name=model_name
).set(dataset_result["number_of_drifted_columns"])
# Per-feature PSI
feature_results = result["metrics"][1]["result"]["drift_by_columns"]
for feature_name, feature_data in feature_results.items():
psi_value = feature_data.get("stattest_threshold", 0)
actual_stat = feature_data.get("drift_score", 0)
DRIFT_GAUGE.labels(
feature_name=feature_name,
model_name=model_name,
model_version=model_version
).set(actual_stat)
logger.info(f"Drift metrics updated for {model_name} v{model_version}")
return drift_detected
@app.get("/metrics")
async def metrics():
"""Endpoint Prometheus metrics."""
return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
@app.post("/drift/check")
async def trigger_drift_check(background_tasks: BackgroundTasks):
"""Trigger manuale del drift check."""
background_tasks.add_task(run_drift_check_job)
return {"status": "drift check started"}
@app.get("/health")
async def health():
return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
Configurazione Prometheus e Grafana
La configurazione di Prometheus per scraping delle metriche ML e semplice: aggiungi il monitoring service come target nel file di configurazione.
# prometheus.yml
global:
scrape_interval: 60s
evaluation_interval: 60s
rule_files:
- "ml_drift_alerts.yml"
alerting:
alertmanagers:
- static_configs:
- targets: ["alertmanager:9093"]
scrape_configs:
- job_name: "ml-monitoring"
static_configs:
- targets: ["ml-monitoring-service:8000"]
metrics_path: "/metrics"
scrape_interval: 60s
- job_name: "model-serving"
static_configs:
- targets: ["fastapi-serving:8080"]
metrics_path: "/metrics"
---
# ml_drift_alerts.yml
groups:
- name: ml_drift_alerts
rules:
- alert: HighFeatureDrift
expr: ml_feature_drift_psi{} > 0.2
for: 5m
labels:
severity: warning
annotations:
summary: "High drift detected on feature {{ $labels.feature_name }}"
description: "PSI = {{ $value | humanize }} for feature {{ $labels.feature_name }}"
- alert: DatasetDriftDetected
expr: ml_dataset_drift_detected == 1
for: 10m
labels:
severity: critical
annotations:
summary: "Dataset-level drift detected for model {{ $labels.model_name }}"
description: "Model performance may be degraded. Consider retraining."
- alert: LowEstimatedAUC
expr: ml_estimated_auc < 0.70
for: 15m
labels:
severity: critical
annotations:
summary: "Estimated AUC dropped below threshold"
description: "Estimated AUC = {{ $value | humanize }} for model {{ $labels.model_name }}"
Dashboard Grafana: Metriche Chiave da Monitorare
- PSI per feature: heatmap con soglie 0.1/0.2 colorate (verde/giallo/rosso)
- Drift score nel tempo: grafico a linee per feature critiche
- AUC stimata (CBPE): time series con bande di confidenza
- Numero di feature driftate: gauge con soglia di alert
- Distribuzione predizioni: istogramma score di probabilità
- Latenza e throughput: panel standard per SLA monitoring
Pipeline di Retraining Automatico
Rilevare il drift e necessario ma non sufficiente: bisogna anche reagire automaticamente. Una pipeline di retraining automatico deve essere attivata da alert di drift, validare il nuovo modello prima di sostituire quello in produzione e garantire rollback in caso di regressione delle performance.
# retraining_pipeline.py
# Pipeline di retraining automatico con MLflow
import mlflow
import mlflow.sklearn
import pandas as pd
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from datetime import datetime
import logging
import requests
logger = logging.getLogger(__name__)
MLFLOW_TRACKING_URI = "http://mlflow-server:5000"
MODEL_NAME = "churn-prediction"
MIN_AUC_THRESHOLD = 0.72 # AUC minima per promuovere in produzione
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
def load_fresh_training_data() -> pd.DataFrame:
"""Carica dati aggiornati per il retraining."""
# In produzione: query al feature store o data warehouse
df = pd.read_parquet("data/training_data_fresh.parquet")
logger.info(f"Loaded {len(df)} training samples")
return df
def train_new_model(df: pd.DataFrame) -> tuple:
"""Addestra un nuovo modello con i dati freschi."""
feature_columns = [
"age", "tenure_months", "monthly_charges",
"total_charges", "num_support_tickets",
"contract_type_encoded", "payment_method_encoded"
]
target_column = "churned"
X = df[feature_columns]
y = df[target_column]
X_train, X_val, y_train, y_val = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
model = GradientBoostingClassifier(
n_estimators=200,
max_depth=4,
learning_rate=0.05,
subsample=0.8,
random_state=42
)
model.fit(X_train, y_train)
y_pred_proba = model.predict_proba(X_val)[:, 1]
y_pred = model.predict(X_val)
metrics = {
"auc": roc_auc_score(y_val, y_pred_proba),
"f1": f1_score(y_val, y_pred),
"precision": precision_score(y_val, y_pred),
"recall": recall_score(y_val, y_pred),
"val_samples": len(X_val)
}
return model, metrics, feature_columns
def register_and_promote_model(
model,
metrics: dict,
feature_columns: list,
trigger_reason: str
) -> bool:
"""Registra il modello in MLflow e promuovilo in produzione se supera la soglia."""
with mlflow.start_run(run_name=f"retrain_{datetime.utcnow().strftime('%Y%m%d_%H%M')}") as run:
# Log params
mlflow.log_param("trigger_reason", trigger_reason)
mlflow.log_param("training_timestamp", datetime.utcnow().isoformat())
mlflow.log_param("features", feature_columns)
# Log metrics
for metric_name, metric_value in metrics.items():
if isinstance(metric_value, (int, float)):
mlflow.log_metric(metric_name, metric_value)
# Log model
mlflow.sklearn.log_model(
model,
artifact_path="model",
registered_model_name=MODEL_NAME
)
run_id = run.info.run_id
logger.info(f"Model registered with run_id={run_id}, AUC={metrics['auc']:.4f}")
# Promuovi in produzione se supera la soglia
if metrics["auc"] >= MIN_AUC_THRESHOLD:
client = mlflow.tracking.MlflowClient()
latest_version = client.get_latest_versions(MODEL_NAME, stages=["None"])[0]
client.transition_model_version_stage(
name=MODEL_NAME,
version=latest_version.version,
stage="Production",
archive_existing_versions=True
)
logger.info(f"Model v{latest_version.version} promoted to Production")
send_slack_notification(f"Model retrained and promoted. AUC={metrics['auc']:.4f}")
return True
else:
logger.warning(f"Model AUC {metrics['auc']:.4f} below threshold {MIN_AUC_THRESHOLD}. Not promoting.")
send_slack_notification(
f"Retraining completed but model below threshold. AUC={metrics['auc']:.4f}. Manual review needed.",
level="warning"
)
return False
def send_slack_notification(message: str, level: str = "info"):
"""Invia notifica Slack (o webhook generico)."""
webhook_url = "https://hooks.slack.com/services/YOUR/WEBHOOK/URL"
color = "#36a64f" if level == "info" else "#ff0000"
payload = {
"attachments": [{
"color": color,
"title": "MLOps Retraining Alert",
"text": message,
"footer": f"ML Platform | {datetime.utcnow().isoformat()}"
}]
}
try:
requests.post(webhook_url, json=payload, timeout=5)
except Exception as e:
logger.error(f"Failed to send Slack notification: {e}")
def run_retraining_pipeline(trigger_reason: str = "drift_detected"):
"""Entry point della pipeline di retraining."""
logger.info(f"Starting retraining pipeline. Trigger: {trigger_reason}")
df = load_fresh_training_data()
model, metrics, feature_columns = train_new_model(df)
promoted = register_and_promote_model(model, metrics, feature_columns, trigger_reason)
logger.info(f"Retraining pipeline completed. Promoted: {promoted}")
return promoted
if __name__ == "__main__":
run_retraining_pipeline(trigger_reason="manual_trigger")
Strategie di Trigger per il Retraining
Definire quando fare retraining e tanto importante quanto come farlo. Esistono tre strategie principali, ognuna con vantaggi e limitazioni:
Strategie di Retraining a Confronto
- Schedule-based (calendario): Retraining periodico fisso (settimanale, mensile). Semplice da implementare ma inefficiente: fa retraining anche quando non serve e potrebbe non fare retraining abbastanza spesso durante periodi di drift rapido.
- Performance-based: Retraining quando le metriche di performance scendono sotto una soglia. Richiede ground truth disponibile rapidamente. Ideale per modelli con feedback loop veloce (es. click-through rate, conversion).
- Drift-based: Retraining quando viene rilevato drift statisticamente significativo nelle feature o nelle predizioni. Non richiede label. Approccio proattivo che previene il degrado prima che impatti le performance. Rischio di falsi positivi.
- Ibrido (raccomandato): Combina drift detection come trigger primario con validazione delle performance come gate di qualità prima della promozione in produzione. Aggiunge anche un retraining periodico di fallback.
Configurazione Completa con Docker Compose
Per ambienti di sviluppo e staging, Docker Compose permette di avviare l'intero stack di monitoring in modo rapido e riproducibile.
# docker-compose.monitoring.yml
version: "3.8"
services:
# ML Monitoring Service (FastAPI + Evidently)
ml-monitoring:
build: ./monitoring_service
ports:
- "8001:8000"
environment:
- MLFLOW_TRACKING_URI=http://mlflow:5000
- REFERENCE_DATA_PATH=/data/reference.parquet
volumes:
- ./data:/data
- ./reports:/reports
depends_on:
- mlflow
# MLflow Tracking Server
mlflow:
image: ghcr.io/mlflow/mlflow:v2.11.0
ports:
- "5000:5000"
command: >
mlflow server
--host 0.0.0.0
--port 5000
--backend-store-uri postgresql://mlflow:mlflow@postgres/mlflow
--default-artifact-root s3://mlflow-artifacts/
depends_on:
- postgres
# PostgreSQL per MLflow
postgres:
image: postgres:15-alpine
environment:
- POSTGRES_USER=mlflow
- POSTGRES_PASSWORD=mlflow
- POSTGRES_DB=mlflow
volumes:
- postgres_data:/var/lib/postgresql/data
# Prometheus
prometheus:
image: prom/prometheus:v2.50.1
ports:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
- ./monitoring/alerts.yml:/etc/prometheus/alerts.yml
- prometheus_data:/prometheus
command:
- "--config.file=/etc/prometheus/prometheus.yml"
- "--storage.tsdb.retention.time=30d"
# Grafana
grafana:
image: grafana/grafana:10.3.3
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
- GF_USERS_ALLOW_SIGN_UP=false
volumes:
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
- grafana_data:/var/lib/grafana
depends_on:
- prometheus
# Alertmanager
alertmanager:
image: prom/alertmanager:v0.27.0
ports:
- "9093:9093"
volumes:
- ./monitoring/alertmanager.yml:/etc/alertmanager/alertmanager.yml
volumes:
postgres_data:
prometheus_data:
grafana_data:
Rozpočet <5 000 EUR/rok pro malé a střední podniky
Kompletní systém detekce posunu nevyžaduje podnikový rozpočet. S přístupem open source a cloud-native, je možné udržovat robustní systém s minimálními náklady:
- Evidentně AI + NannyML: Open source, zdarma
- MLflow (samohoštěný): Open-source, pouze náklady na infrastrukturu
- Prometheus + Grafana: Open source, zdarma
- Výpočet (VPS/cloud): ~50-100 EUR/měsíc pro průměrný VM (600-1200 EUR/rok)
- Úložiště kompatibilní s S3: ~20 EUR/měsíc za 500 GB (240 EUR/rok)
- Odhadovaný součet: ~1000-2000 EUR/rok za full stack
Nejlepší postupy pro detekci posunu ve výrobě
Kontrolní seznam výroby
- Definujte statistický základ před nasazením: Spusťte detekci posunu proti sobě na ověřovací sadě pro kalibraci prahů. PSI > 0 na datech stacionární indikuje překročení prahu.
- Použijte vhodná časová okna: Neporovnávejte veškerý provoz historické s dneškem. Použijte posuvná okna (7/14/30 dní) k zachycení nedávného posunu.
- Upřednostněte funkce podle důležitosti: Monitorujte agresivněji Vlastnosti SHAP s vysokým dopadem. Ne všechny drifty jsou stejně kritické.
- Rozlišujte technický drift od sémantického driftu: Změna formátu pole (např. od řetězce k číslu) a technická chyba, nikoli posun ML. Přidat samostatné kontroly kvality dat.
- Vyhněte se bdělé únavě: Nejprve nastavte konzervativní prahové hodnoty a časem se zpřesňuje. Příliš mnoho upozornění vede k jejich ignorování.
- Protokolování rozhodnutí o rekvalifikaci: Každá rekvalifikace musí být vykreslený pomocí MLflow, včetně důvodu spuštění, před/po metrikách a propagovaná verze modelu.
- Testování samotného detektoru: Pravidelně kontrolujte, že systém detekce funguje správně s testováním vkládání dat (vstřikování syntetického driftu a ověřte, zda je detekován).
Anti-vzory, kterým je třeba se vyhnout
- Kvalitní bezbránové automatické přeškolení: Nepropagujte v vytvoření nově trénovaného modelu bez ověřování výkonu. Přeškolení na kontaminovaná data může model zhoršit.
- Pouze monitorovací výstup: Sledovat pouze předpovědi bez vstupní funkce znemožňují diagnostikovat příčinu driftu.
- Pevné prahové hodnoty pro všechny modely: Každý model má citlivost odlišné od driftu. PSI > 0,2 může být pro kritický model katastrofální a pro model s nízkou prioritou irelevantní.
- Ignorujte posun konceptu: Pokud se štítky zpětné vazby neshromažďují z produkčního modelu není možné přímo detekovat posun konceptu. Investujte do infrastruktury zpětné vazby.
Závěry a další kroky
Automatická detekce driftu a rekvalifikační systém je srdcem každého vyspělého MLOps. Bez aktivního monitorování ML modely ve výrobě tiše degradují a generují chybná rozhodnutí, která mohou stát mnohem více než náklady na samotný monitorovací systém.
V této příručce jsme vytvořili kompletní systém: z teoretického porozumění čtyři typy driftu až po praktickou implementaci s Evidently AI pro interaktivní zprávy, NannyML pro odhad výkonu bez štítků a Alibi Detect pro detekci pokročilé vícerozměrné. Vše jsme integrovali s Prometheem, Grafanou a potrubím automatické přeškolení s MLflow.
Dalším krokem je integrace tohoto systému se službou FastAPI, kterou jsme viděli v předchozím článku a se škálováním Kubernetes, které uvidíme v dalším. S těmito komponenty, budete mít kompletní, produkční a udržovatelný systém MLOps.
Série MLOps pokračuje
- Předchozí článek: Sledování experimentu s MLflow: Kompletní průvodce - zaznamenávat experimenty a porovnávat modely
- Další článek: Servírovací modely: FastAPI + Uvicorn ve výrobě - vytvářet škálovatelná inferenční API
- Další informace: Škálování ML na Kubernetes - zorganizovat nasazení s KubeFlow a Seldonem
- Související série: Pokročilé hluboké učení - monitorování pro komplexní neurální modely







