diff --git a/data/processed/test_results.json b/data/processed/test_results.json index 5b23fc1..f3f72bd 100644 --- a/data/processed/test_results.json +++ b/data/processed/test_results.json @@ -1,37 +1,51 @@ { "loss_metrics": { - "loss": 2.8661977450052896, - "loss_size": 0.44916408757368725, - "loss_pdi": 0.5041926403840383, - "loss_ee": 0.9021427234013876, - "loss_delivery": 0.5761533578236898, - "loss_biodist": 0.4019051690896352, - "loss_toxic": 0.03263980595511384, - "acc_pdi": 0.7633587786259542, - "acc_ee": 0.6641221374045801, - "acc_toxic": 0.9702970297029703 + "loss": 2.5374555587768555, + "loss_size": 0.1886825958887736, + "loss_pdi": 0.45798932512601215, + "loss_ee": 0.829658567905426, + "loss_delivery": 0.4857304096221924, + "loss_biodist": 0.5346279243628184, + "loss_toxic": 0.04076674363265435, + "acc_pdi": 0.7862595419847328, + "acc_ee": 0.6793893129770993, + "acc_toxic": 0.9801980198019802 }, "detailed_metrics": { "size": { - "mse": 0.41126506251447736, - "rmse": 0.6412995107704959, - "mae": 0.41415552388095633, - "r2": -0.9333718010891026 + "mse": 0.1669999969286325, + "rmse": 0.4086563310761654, + "mae": 0.26111859684375066, + "r2": 0.2149270281561566 }, "delivery": { - "mse": 0.6277965050686476, - "rmse": 0.7923361061245711, - "mae": 0.5387302115022443, - "r2": 0.24206702565575944 + "mse": 0.5193460523366603, + "rmse": 0.7206566813238189, + "mae": 0.4828052782115008, + "r2": 0.37299826459145 }, "pdi": { - "accuracy": 0.7633587786259542 + "accuracy": 0.7862595419847328, + "precision": 0.7282763532763532, + "recall": 0.6907738095238095, + "f1": 0.7041935483870968 }, "ee": { - "accuracy": 0.6641221374045801 + "accuracy": 0.6793893129770993, + "precision": 0.612247574088644, + "recall": 0.6062951496388029, + "f1": 0.6069449904342585 }, "toxic": { - "accuracy": 0.9702970297029703 + "accuracy": 0.9801980198019802, + "precision": 0.5, + "recall": 0.4900990099009901, + "f1": 0.495 + }, + "biodist": { + "n_samples": 101, + "kl_divergence": 0.2931957937514963, + "js_divergence": 0.07706768601895059 } } } \ No newline at end of file diff --git a/lnp_ml/modeling/predict.py b/lnp_ml/modeling/predict.py index 1fe6b43..88facd5 100644 --- a/lnp_ml/modeling/predict.py +++ b/lnp_ml/modeling/predict.py @@ -217,15 +217,31 @@ def test( """ import json import numpy as np + from scipy.special import rel_entr from sklearn.metrics import ( mean_squared_error, mean_absolute_error, r2_score, accuracy_score, - classification_report, + precision_score, + recall_score, + f1_score, ) from lnp_ml.modeling.trainer import validate + def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float: + """计算 KL 散度 KL(p || q)""" + p = np.clip(p, eps, 1.0) + q = np.clip(q, eps, 1.0) + return float(np.sum(rel_entr(p, q), axis=-1).mean()) + + def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float: + """计算 JS 散度""" + p = np.clip(p, eps, 1.0) + q = np.clip(q, eps, 1.0) + m = 0.5 * (p + q) + return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean()) + logger.info(f"Using device: {device}") device_obj = torch.device(device) @@ -287,6 +303,9 @@ def test( y_pred = np.array(predictions["pdi"])[mask] results["detailed_metrics"]["pdi"] = { "accuracy": float(accuracy_score(y_true, y_pred)), + "precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)), + "recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)), + "f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)), } # 分类指标:EE @@ -299,6 +318,9 @@ def test( y_pred = np.array(predictions["ee"])[mask] results["detailed_metrics"]["ee"] = { "accuracy": float(accuracy_score(y_true, y_pred)), + "precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)), + "recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)), + "f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)), } # 分类指标:toxic @@ -309,6 +331,28 @@ def test( y_pred = np.array(predictions["toxic"])[mask.values] results["detailed_metrics"]["toxic"] = { "accuracy": float(accuracy_score(y_true, y_pred)), + "precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)), + "recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)), + "f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)), + } + + # 分布指标:biodist + biodist_cols = [ + "Biodistribution_lymph_nodes", "Biodistribution_heart", "Biodistribution_liver", + "Biodistribution_spleen", "Biodistribution_lung", "Biodistribution_kidney", "Biodistribution_muscle" + ] + if all(c in test_df.columns for c in biodist_cols): + biodist_true = test_df[biodist_cols].values + biodist_pred = np.array(predictions["biodist"]) + # mask: 有效样本是 sum > 0 且无 NaN + mask = (biodist_true.sum(axis=1) > 0) & (~np.isnan(biodist_true).any(axis=1)) + if mask.any(): + y_true = biodist_true[mask] + y_pred = biodist_pred[mask] + results["detailed_metrics"]["biodist"] = { + "n_samples": int(mask.sum()), + "kl_divergence": kl_divergence(y_true, y_pred), + "js_divergence": js_divergence(y_true, y_pred), } # 打印结果 diff --git a/lnp_ml/modeling/train_cv.py b/lnp_ml/modeling/train_cv.py index 270b850..1b35a8b 100644 --- a/lnp_ml/modeling/train_cv.py +++ b/lnp_ml/modeling/train_cv.py @@ -391,13 +391,30 @@ def test( 使用每个 fold 的模型在对应的测试集上评估,然后汇总结果。 """ + from scipy.special import rel_entr from sklearn.metrics import ( mean_squared_error, mean_absolute_error, r2_score, accuracy_score, + precision_score, + recall_score, + f1_score, ) + def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float: + """计算 KL 散度 KL(p || q)""" + p = np.clip(p, eps, 1.0) + q = np.clip(q, eps, 1.0) + return float(np.sum(rel_entr(p, q), axis=-1).mean()) + + def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float: + """计算 JS 散度""" + p = np.clip(p, eps, 1.0) + q = np.clip(q, eps, 1.0) + m = 0.5 * (p + q) + return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean()) + logger.info(f"Using device: {device}") device = torch.device(device) @@ -413,10 +430,10 @@ def test( fold_results = [] # 用于汇总所有 fold 的预测 all_preds = { - "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [] + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": [] } all_targets = { - "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [] + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": [] } for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"): @@ -522,6 +539,14 @@ def test( toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int) fold_preds["toxic"].extend(toxic_preds.tolist()) fold_targets["toxic"].extend(toxic_targets.tolist()) + + # Biodist (distribution) + if "biodist" in masks and masks["biodist"].any(): + mask = masks["biodist"] + biodist_preds = outputs["biodist"][mask].cpu().numpy() + biodist_targets = targets["biodist"][mask].cpu().numpy() + fold_preds["biodist"].extend(biodist_preds.tolist()) + fold_targets["biodist"].extend(biodist_targets.tolist()) # 计算当前 fold 的指标 fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)} @@ -546,8 +571,21 @@ def test( fold_metrics[task] = { "n": len(p), "accuracy": float(accuracy_score(t, p)), + "precision": float(precision_score(t, p, average="macro", zero_division=0)), + "recall": float(recall_score(t, p, average="macro", zero_division=0)), + "f1": float(f1_score(t, p, average="macro", zero_division=0)), } + # 分布任务指标 + if fold_preds["biodist"]: + p = np.array(fold_preds["biodist"]) + t = np.array(fold_targets["biodist"]) + fold_metrics["biodist"] = { + "n": len(p), + "kl_divergence": kl_divergence(t, p), + "js_divergence": js_divergence(t, p), + } + fold_results.append(fold_metrics) # 汇总到全局 @@ -564,6 +602,10 @@ def test( for task in ["pdi", "ee", "toxic"]: if task in fold_metrics and isinstance(fold_metrics[task], dict): log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}") + log_parts.append(f"{task}_f1={fold_metrics[task]['f1']:.4f}") + if "biodist" in fold_metrics and isinstance(fold_metrics["biodist"], dict): + log_parts.append(f"biodist_KL={fold_metrics['biodist']['kl_divergence']:.4f}") + log_parts.append(f"biodist_JS={fold_metrics['biodist']['js_divergence']:.4f}") logger.info(", ".join(log_parts)) # 计算跨 fold 汇总统计 @@ -581,12 +623,26 @@ def test( for task in ["pdi", "ee", "toxic"]: accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)] + f1s = [r[task]["f1"] for r in fold_results if task in r and isinstance(r[task], dict)] if accs: summary_stats[task] = { "accuracy_mean": float(np.mean(accs)), "accuracy_std": float(np.std(accs)), + "f1_mean": float(np.mean(f1s)), + "f1_std": float(np.std(f1s)), } + # 分布任务汇总 + kls = [r["biodist"]["kl_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)] + jss = [r["biodist"]["js_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)] + if kls: + summary_stats["biodist"] = { + "kl_mean": float(np.mean(kls)), + "kl_std": float(np.std(kls)), + "js_mean": float(np.mean(jss)), + "js_std": float(np.std(jss)), + } + # 计算整体 pooled 指标 overall = {} for task in ["size", "delivery"]: @@ -608,8 +664,21 @@ def test( overall[task] = { "n_samples": len(p), "accuracy": float(accuracy_score(t, p)), + "precision": float(precision_score(t, p, average="macro", zero_division=0)), + "recall": float(recall_score(t, p, average="macro", zero_division=0)), + "f1": float(f1_score(t, p, average="macro", zero_division=0)), } + # 分布任务 + if all_preds["biodist"]: + p = np.array(all_preds["biodist"]) + t = np.array(all_targets["biodist"]) + overall["biodist"] = { + "n_samples": len(p), + "kl_divergence": kl_divergence(t, p), + "js_divergence": js_divergence(t, p), + } + # 打印汇总结果 logger.info("\n" + "=" * 60) logger.info("CV TEST EVALUATION RESULTS") @@ -619,15 +688,19 @@ def test( for task, stats in summary_stats.items(): if "rmse_mean" in stats: logger.info(f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}") - else: - logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}") + elif "accuracy_mean" in stats: + logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}") + elif "kl_mean" in stats: + logger.info(f" {task}: KL={stats['kl_mean']:.4f}±{stats['kl_std']:.4f}, JS={stats['js_mean']:.4f}±{stats['js_std']:.4f}") logger.info(f"\n[Overall (all samples pooled)]") for task, metrics in overall.items(): if "rmse" in metrics: logger.info(f" {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}") - else: - logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}") + elif "accuracy" in metrics: + logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}, Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}") + elif "kl_divergence" in metrics: + logger.info(f" {task} (n={metrics['n_samples']}): KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}") # 保存结果 results = { diff --git a/models/finetune_cv/test_results.json b/models/finetune_cv/test_results.json index aa9bc22..72d2e3f 100644 --- a/models/finetune_cv/test_results.json +++ b/models/finetune_cv/test_results.json @@ -17,15 +17,29 @@ }, "pdi": { "n": 95, - "accuracy": 0.6105263157894737 + "accuracy": 0.6105263157894737, + "precision": 0.20350877192982456, + "recall": 0.3333333333333333, + "f1": 0.25272331154684097 }, "ee": { "n": 95, - "accuracy": 0.6631578947368421 + "accuracy": 0.6631578947368421, + "precision": 0.22580645161290322, + "recall": 0.328125, + "f1": 0.267515923566879 }, "toxic": { "n": 66, - "accuracy": 0.8939393939393939 + "accuracy": 0.8939393939393939, + "precision": 0.44696969696969696, + "recall": 0.5, + "f1": 0.472 + }, + "biodist": { + "n": 66, + "kl_divergence": 0.8735123101690443, + "js_divergence": 0.2203766048579219 } }, { @@ -45,15 +59,29 @@ }, "pdi": { "n": 195, - "accuracy": 0.7076923076923077 + "accuracy": 0.7076923076923077, + "precision": 0.35384615384615387, + "recall": 0.5, + "f1": 0.4144144144144144 }, "ee": { "n": 195, - "accuracy": 0.4205128205128205 + "accuracy": 0.4205128205128205, + "precision": 0.14017094017094017, + "recall": 0.3333333333333333, + "f1": 0.19735258724428398 }, "toxic": { "n": 123, - "accuracy": 1.0 + "accuracy": 1.0, + "precision": 1.0, + "recall": 1.0, + "f1": 1.0 + }, + "biodist": { + "n": 123, + "kl_divergence": 1.218490162626268, + "js_divergence": 0.3069412988090679 } }, { @@ -73,15 +101,29 @@ }, "pdi": { "n": 51, - "accuracy": 0.8823529411764706 + "accuracy": 0.8823529411764706, + "precision": 0.29411764705882354, + "recall": 0.3333333333333333, + "f1": 0.3125 }, "ee": { "n": 51, - "accuracy": 0.8431372549019608 + "accuracy": 0.8431372549019608, + "precision": 0.28104575163398693, + "recall": 0.3333333333333333, + "f1": 0.3049645390070922 }, "toxic": { "n": 47, - "accuracy": 0.851063829787234 + "accuracy": 0.851063829787234, + "precision": 0.425531914893617, + "recall": 0.5, + "f1": 0.4597701149425288 + }, + "biodist": { + "n": 45, + "kl_divergence": 1.0902737810034395, + "js_divergence": 0.27372590260311874 } }, { @@ -101,15 +143,29 @@ }, "pdi": { "n": 66, - "accuracy": 0.8484848484848485 + "accuracy": 0.8484848484848485, + "precision": 0.42424242424242425, + "recall": 0.5, + "f1": 0.4590163934426229 }, "ee": { "n": 66, - "accuracy": 0.18181818181818182 + "accuracy": 0.18181818181818182, + "precision": 0.203921568627451, + "recall": 0.2707070707070707, + "f1": 0.12297410192147035 }, "toxic": { "n": 62, - "accuracy": 1.0 + "accuracy": 1.0, + "precision": 1.0, + "recall": 1.0, + "f1": 1.0 + }, + "biodist": { + "n": 62, + "kl_divergence": 0.9434472801013084, + "js_divergence": 0.20391570021898991 } }, { @@ -129,15 +185,29 @@ }, "pdi": { "n": 27, - "accuracy": 0.8888888888888888 + "accuracy": 0.8888888888888888, + "precision": 0.4444444444444444, + "recall": 0.5, + "f1": 0.47058823529411764 }, "ee": { "n": 27, - "accuracy": 0.5925925925925926 + "accuracy": 0.5925925925925926, + "precision": 0.19753086419753085, + "recall": 0.3333333333333333, + "f1": 0.24806201550387597 }, "toxic": { "n": 15, - "accuracy": 1.0 + "accuracy": 1.0, + "precision": 1.0, + "recall": 1.0, + "f1": 1.0 + }, + "biodist": { + "n": 15, + "kl_divergence": 0.7615404990024607, + "js_divergence": 0.20916908426734354 } } ], @@ -156,15 +226,27 @@ }, "pdi": { "accuracy_mean": 0.7875890604063979, - "accuracy_std": 0.11016791908756088 + "accuracy_std": 0.11016791908756088, + "f1_mean": 0.3818484709395992, + "f1_std": 0.08529090446864619 }, "ee": { "accuracy_mean": 0.5402437489124795, - "accuracy_std": 0.22467627690136344 + "accuracy_std": 0.22467627690136344, + "f1_mean": 0.2281738334487203, + "f1_std": 0.063019179670124 }, "toxic": { "accuracy_mean": 0.9490006447453256, - "accuracy_std": 0.06391582554207781 + "accuracy_std": 0.06391582554207781, + "f1_mean": 0.7863540229885058, + "f1_std": 0.26169039387919035 + }, + "biodist": { + "kl_mean": 0.9774528065805042, + "kl_std": 0.1608761675751487, + "js_mean": 0.24282571815128842, + "js_std": 0.04053726747029075 } }, "overall": { @@ -184,15 +266,29 @@ }, "pdi": { "n_samples": 434, - "accuracy": 0.7396313364055299 + "accuracy": 0.7396313364055299, + "precision": 0.18490783410138248, + "recall": 0.25, + "f1": 0.21258278145695364 }, "ee": { "n_samples": 434, - "accuracy": 0.4976958525345622 + "accuracy": 0.4976958525345622, + "precision": 0.21063404810247777, + "recall": 0.2839160839160839, + "f1": 0.23684873775319115 }, "toxic": { "n_samples": 313, - "accuracy": 0.9552715654952076 + "accuracy": 0.9552715654952076, + "precision": 0.4776357827476038, + "recall": 0.5, + "f1": 0.48856209150326796 + }, + "biodist": { + "n_samples": 311, + "kl_divergence": 1.0498561462079121, + "js_divergence": 0.25851000311532496 } } } \ No newline at end of file