lnp_ml/scripts/evaluate_external.py
2026-01-21 16:20:10 +08:00

205 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""评估外部数据 cross-validation 结果的脚本"""
import json
from pathlib import Path
from typing import Dict
import numpy as np
import pandas as pd
import typer
from loguru import logger
from sklearn.metrics import mean_squared_error, r2_score
from lnp_ml.config import EXTERNAL_DATA_DIR
app = typer.Typer()
def evaluate_split(
test_path: Path,
preds_path: Path,
y_col: str = "quantified_delivery",
) -> Dict[str, float]:
"""
评估单个 split 的预测结果。
Args:
test_path: groundtruth CSV 路径(包含 smiles 和 y_col
preds_path: 预测结果 CSV 路径(包含 smiles 和 y_col
y_col: Y 值列名
Returns:
Dict with rmse, r2, mse, mae, n_samples
"""
# 读取数据
test_df = pd.read_csv(test_path)
preds_df = pd.read_csv(preds_path)
# 检查列是否存在
if y_col not in test_df.columns:
raise ValueError(f"Column '{y_col}' not found in {test_path}")
if y_col not in preds_df.columns:
raise ValueError(f"Column '{y_col}' not found in {preds_path}")
# 直接按行对齐test.csv 和 preds.csv 顺序一致)
assert len(test_df) == len(preds_df), f"行数不匹配: test={len(test_df)}, preds={len(preds_df)}"
# 验证 smiles 列一致(如果存在)
if "smiles" in test_df.columns and "smiles" in preds_df.columns:
if not (test_df["smiles"].values == preds_df["smiles"].values).all():
logger.warning("smiles 列顺序不一致,请检查数据")
y_true = test_df[y_col].values
y_pred = preds_df[y_col].values
# 计算指标
mse = float(mean_squared_error(y_true, y_pred))
rmse = float(np.sqrt(mse))
r2 = float(r2_score(y_true, y_pred))
mae = float(np.mean(np.abs(y_true - y_pred)))
correlation = float(np.corrcoef(y_true, y_pred)[0, 1])
return {
"n_samples": len(y_true),
"mse": mse,
"rmse": rmse,
"mae": mae,
"r2": r2,
"correlation": correlation,
}
@app.command()
def main(
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
output_path: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON" / "evaluation_results.json",
y_col: str = "quantified_delivery",
):
"""
评估 all_amine_split_for_LiON 的 cross-validation 结果。
计算每个 split 和整体的 RMSE、R² 指标。
"""
logger.info(f"Evaluating cross-validation results from {data_dir}")
# 找到所有 cv_* 目录
cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
if not cv_dirs:
logger.error(f"No cv_* directories found in {data_dir}")
raise typer.Exit(1)
logger.info(f"Found {len(cv_dirs)} splits: {[d.name for d in cv_dirs]}")
# 评估每个 split
split_results = {}
all_y_true = []
all_y_pred = []
for cv_dir in cv_dirs:
split_name = cv_dir.name
test_path = cv_dir / "test.csv"
preds_path = cv_dir / "preds.csv"
if not test_path.exists():
logger.warning(f" {split_name}: test.csv not found, skipping")
continue
if not preds_path.exists():
logger.warning(f" {split_name}: preds.csv not found, skipping")
continue
# 评估
metrics = evaluate_split(test_path, preds_path, y_col)
split_results[split_name] = metrics
logger.info(
f" {split_name}: n={metrics['n_samples']}, "
f"RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}"
)
# 收集所有数据用于计算整体指标
test_df = pd.read_csv(test_path)
preds_df = pd.read_csv(preds_path)
all_y_true.extend(test_df[y_col].tolist())
all_y_pred.extend(preds_df[y_col].tolist())
# 计算整体指标
all_y_true = np.array(all_y_true)
all_y_pred = np.array(all_y_pred)
overall_mse = float(mean_squared_error(all_y_true, all_y_pred))
overall_rmse = float(np.sqrt(overall_mse))
overall_r2 = float(r2_score(all_y_true, all_y_pred))
overall_mae = float(np.mean(np.abs(all_y_true - all_y_pred)))
overall_correlation = float(np.corrcoef(all_y_true, all_y_pred)[0, 1])
overall_results = {
"n_samples": len(all_y_true),
"mse": overall_mse,
"rmse": overall_rmse,
"mae": overall_mae,
"r2": overall_r2,
"correlation": overall_correlation,
}
# 计算 split 指标的均值和标准差
split_metrics = list(split_results.values())
summary_stats = {}
for metric in ["rmse", "r2", "mae", "correlation"]:
values = [s[metric] for s in split_metrics]
summary_stats[metric] = {
"mean": float(np.mean(values)),
"std": float(np.std(values)),
"min": float(np.min(values)),
"max": float(np.max(values)),
}
# 汇总结果
results = {
"data_dir": str(data_dir),
"y_col": y_col,
"n_splits": len(split_results),
"split_results": split_results,
"overall": overall_results,
"summary_stats": summary_stats,
}
# 打印结果
logger.info("\n" + "=" * 60)
logger.info("CROSS-VALIDATION EVALUATION RESULTS")
logger.info("=" * 60)
logger.info(f"\n[Per-Split Results]")
for split_name, metrics in sorted(split_results.items()):
logger.info(
f" {split_name}: RMSE={metrics['rmse']:.4f}, "
f"R²={metrics['r2']:.4f}, "
f"MAE={metrics['mae']:.4f}, "
f"Corr={metrics['correlation']:.4f}"
)
logger.info(f"\n[Summary Statistics (across {len(split_results)} splits)]")
for metric, stats in summary_stats.items():
logger.info(
f" {metric.upper():12s}: "
f"mean={stats['mean']:.4f} ± {stats['std']:.4f} "
f"(min={stats['min']:.4f}, max={stats['max']:.4f})"
)
logger.info(f"\n[Overall (all {overall_results['n_samples']} samples pooled)]")
logger.info(f" RMSE: {overall_results['rmse']:.4f}")
logger.info(f" R²: {overall_results['r2']:.4f}")
logger.info(f" MAE: {overall_results['mae']:.4f}")
logger.info(f" Correlation: {overall_results['correlation']:.4f}")
# 保存结果
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
logger.success(f"\nSaved results to {output_path}")
if __name__ == "__main__":
app()