mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 17:46:39 +08:00
205 lines
6.5 KiB
Python
205 lines
6.5 KiB
Python
"""评估外部数据 cross-validation 结果的脚本"""
|
||
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Dict
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import typer
|
||
from loguru import logger
|
||
from sklearn.metrics import mean_squared_error, r2_score
|
||
|
||
from lnp_ml.config import EXTERNAL_DATA_DIR
|
||
|
||
|
||
# Typer application object; CLI commands are registered on it via @app.command().
app = typer.Typer()
|
||
|
||
|
||
def evaluate_split(
    test_path: Path,
    preds_path: Path,
    y_col: str = "quantified_delivery",
) -> Dict[str, float]:
    """
    Evaluate the predictions of a single cross-validation split.

    The two CSV files are compared row by row, so they are expected to list
    the same samples in the same order.

    Args:
        test_path: Path to the ground-truth CSV; must contain ``y_col``.
        preds_path: Path to the predictions CSV; must contain ``y_col``.
        y_col: Name of the target column read from both files.

    Returns:
        Dict with the keys ``n_samples``, ``mse``, ``rmse``, ``mae``, ``r2``
        and ``correlation`` (off-diagonal entry of ``np.corrcoef``).

    Raises:
        ValueError: If ``y_col`` is missing from either file, or if the two
            files do not have the same number of rows.
    """
    # Load ground truth and predictions.
    test_df = pd.read_csv(test_path)
    preds_df = pd.read_csv(preds_path)

    # Both files must expose the target column.
    if y_col not in test_df.columns:
        raise ValueError(f"Column '{y_col}' not found in {test_path}")
    if y_col not in preds_df.columns:
        raise ValueError(f"Column '{y_col}' not found in {preds_path}")

    # Rows are aligned purely by position. This is an explicit raise rather
    # than an ``assert`` so the check still runs under ``python -O``.
    if len(test_df) != len(preds_df):
        raise ValueError(f"行数不匹配: test={len(test_df)}, preds={len(preds_df)}")

    # When both files carry a smiles column, use it as a sanity check on the
    # row order; a mismatch is reported but does not abort the evaluation.
    if "smiles" in test_df.columns and "smiles" in preds_df.columns:
        if not (test_df["smiles"].values == preds_df["smiles"].values).all():
            logger.warning("smiles 列顺序不一致,请检查数据")

    y_true = test_df[y_col].values
    y_pred = preds_df[y_col].values

    # Compute the metrics.
    mse = float(mean_squared_error(y_true, y_pred))
    rmse = float(np.sqrt(mse))
    r2 = float(r2_score(y_true, y_pred))
    mae = float(np.mean(np.abs(y_true - y_pred)))
    correlation = float(np.corrcoef(y_true, y_pred)[0, 1])

    return {
        "n_samples": len(y_true),
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
        "correlation": correlation,
    }
|
||
|
||
|
||
@app.command()
def main(
    data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
    output_path: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON" / "evaluation_results.json",
    y_col: str = "quantified_delivery",
):
    """
    Evaluate the cross-validation results of all_amine_split_for_LiON.

    Every ``cv_*`` sub-directory of ``data_dir`` that contains both
    ``test.csv`` and ``preds.csv`` is scored with ``evaluate_split``. The
    per-split metrics, their mean/std/min/max across splits, and the metrics
    over all samples pooled together are logged and written to
    ``output_path`` as JSON.

    Args:
        data_dir: Directory holding the ``cv_*`` split directories.
        output_path: Destination of the JSON report; parent directories are
            created when missing.
        y_col: Name of the target column in ``test.csv`` and ``preds.csv``.

    Raises:
        typer.Exit: With code 1 when ``data_dir`` has no ``cv_*`` directory,
            or when none of them contains both CSV files.
    """
    logger.info(f"Evaluating cross-validation results from {data_dir}")

    # Every sub-directory named cv_* is treated as one split.
    cv_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_"))

    if not cv_dirs:
        logger.error(f"No cv_* directories found in {data_dir}")
        raise typer.Exit(1)

    logger.info(f"Found {len(cv_dirs)} splits: {[d.name for d in cv_dirs]}")

    # Evaluate each split and pool the raw values for the overall metrics.
    split_results = {}
    pooled_true = []
    pooled_pred = []

    for cv_dir in cv_dirs:
        split_name = cv_dir.name
        test_path = cv_dir / "test.csv"
        preds_path = cv_dir / "preds.csv"

        if not test_path.exists():
            logger.warning(f"  {split_name}: test.csv not found, skipping")
            continue
        if not preds_path.exists():
            logger.warning(f"  {split_name}: preds.csv not found, skipping")
            continue

        metrics = evaluate_split(test_path, preds_path, y_col)
        split_results[split_name] = metrics

        logger.info(
            f"  {split_name}: n={metrics['n_samples']}, "
            f"RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}"
        )

        # evaluate_split only returns metrics, so the files are read again
        # here to collect the raw values for the pooled evaluation.
        test_df = pd.read_csv(test_path)
        preds_df = pd.read_csv(preds_path)
        pooled_true.extend(test_df[y_col].tolist())
        pooled_pred.extend(preds_df[y_col].tolist())

    # All splits may have been skipped; the metric functions below cannot
    # handle empty input, so stop with a clear message instead.
    if not split_results:
        logger.error(f"No cv_* directory in {data_dir} contains both test.csv and preds.csv")
        raise typer.Exit(1)

    # Metrics over all samples pooled together.
    all_y_true = np.array(pooled_true)
    all_y_pred = np.array(pooled_pred)

    overall_mse = float(mean_squared_error(all_y_true, all_y_pred))
    overall_rmse = float(np.sqrt(overall_mse))
    overall_r2 = float(r2_score(all_y_true, all_y_pred))
    overall_mae = float(np.mean(np.abs(all_y_true - all_y_pred)))
    overall_correlation = float(np.corrcoef(all_y_true, all_y_pred)[0, 1])

    overall_results = {
        "n_samples": len(all_y_true),
        "mse": overall_mse,
        "rmse": overall_rmse,
        "mae": overall_mae,
        "r2": overall_r2,
        "correlation": overall_correlation,
    }

    # Mean / std / min / max of each metric across the evaluated splits.
    split_metrics = list(split_results.values())
    summary_stats = {}
    for metric in ["rmse", "r2", "mae", "correlation"]:
        values = [s[metric] for s in split_metrics]
        summary_stats[metric] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "min": float(np.min(values)),
            "max": float(np.max(values)),
        }

    # Assemble the report.
    results = {
        "data_dir": str(data_dir),
        "y_col": y_col,
        "n_splits": len(split_results),
        "split_results": split_results,
        "overall": overall_results,
        "summary_stats": summary_stats,
    }

    # Log the report.
    logger.info("\n" + "=" * 60)
    logger.info("CROSS-VALIDATION EVALUATION RESULTS")
    logger.info("=" * 60)

    logger.info("\n[Per-Split Results]")
    for split_name, metrics in sorted(split_results.items()):
        logger.info(
            f"  {split_name}: RMSE={metrics['rmse']:.4f}, "
            f"R²={metrics['r2']:.4f}, "
            f"MAE={metrics['mae']:.4f}, "
            f"Corr={metrics['correlation']:.4f}"
        )

    logger.info(f"\n[Summary Statistics (across {len(split_results)} splits)]")
    for metric, stats in summary_stats.items():
        logger.info(
            f"  {metric.upper():12s}: "
            f"mean={stats['mean']:.4f} ± {stats['std']:.4f} "
            f"(min={stats['min']:.4f}, max={stats['max']:.4f})"
        )

    logger.info(f"\n[Overall (all {overall_results['n_samples']} samples pooled)]")
    logger.info(f"  RMSE: {overall_results['rmse']:.4f}")
    logger.info(f"  R²: {overall_results['r2']:.4f}")
    logger.info(f"  MAE: {overall_results['mae']:.4f}")
    logger.info(f"  Correlation: {overall_results['correlation']:.4f}")

    # Save the report.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.success(f"\nSaved results to {output_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the Typer CLI when this file is executed as a script.
    app()
|