# CLI wiring for this module (Makefile targets):
#   pretrain:      $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG)
#   test_pretrain: $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)


def _finite_or_none(value: float):
    """Return ``value`` as a float, or ``None`` if it is NaN or infinite.

    ``json.dump`` would otherwise serialise non-finite floats as the bare
    tokens ``NaN`` / ``Infinity``, which strict JSON parsers reject.
    """
    import math

    value = float(value)
    return value if math.isfinite(value) else None


@app.command()
def test(
    # Data paths
    val_path: Path = PROCESSED_DATA_DIR / "val_pretrain.parquet",
    model_path: Path = MODELS_DIR / "pretrain_delivery.pt",
    output_path: Path = MODELS_DIR / "pretrain_test_results.json",
    # MPNN options
    use_mpnn: bool = False,
    mpnn_device: str = "cpu",
    # Miscellaneous
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate the pretrain model's delivery predictions on external data.

    Loads the checkpoint at ``model_path``, rebuilds the model from the
    configuration stored in it, predicts over ``val_path`` and reports
    MSE, RMSE, MAE, R² and the correlation coefficient, together with the
    mean/std of targets and predictions. Results are logged and written to
    ``output_path`` as JSON; metrics that are undefined (NaN) are stored as
    ``null``.

    Args:
        val_path: Parquet file holding the evaluation samples.
        model_path: Checkpoint with ``config`` and ``model_state_dict`` entries.
        output_path: Destination JSON file; parent directories are created.
        use_mpnn: Force the MPNN-enabled model even when the checkpoint
            config does not request it.
        mpnn_device: Device passed to the MPNN ensemble.
        batch_size: Number of samples per prediction batch.
        device: Torch device used for the model and tabular features.

    Raises:
        ValueError: If the data yields no batch, or no sample has a valid
            (masked-in) delivery target.
    """
    import numpy as np
    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

    logger.info(f"Using device: {device}")
    device_obj = torch.device(device)

    # Load the checkpoint.
    # NOTE(security): torch.load unpickles the file; only evaluate checkpoints
    # that come from a trusted source.
    logger.info(f"Loading pretrain model from {model_path}")
    checkpoint = torch.load(model_path, map_location=device_obj)
    config = checkpoint["config"]

    # Hyper-parameters shared by both model variants, read from the checkpoint.
    model_kwargs = {
        key: config[key]
        for key in (
            "d_model",
            "num_heads",
            "n_attn_layers",
            "fusion_strategy",
            "head_hidden_dim",
            "dropout",
        )
    }

    # Resolve the MPNN setting: either the checkpoint or the CLI flag enables it.
    # NOTE(review): forcing --use-mpnn on a checkpoint trained without MPNN
    # presumably makes load_state_dict fail on mismatched keys — confirm.
    enable_mpnn = config.get("use_mpnn", False)
    if enable_mpnn or use_mpnn:
        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
        ensemble_paths = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths)} MPNN models")

        model = LNPModel(
            **model_kwargs,
            mpnn_ensemble_paths=ensemble_paths,
            mpnn_device=mpnn_device,
        )
    else:
        model = LNPModelWithoutMPNN(**model_kwargs)

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device_obj)
    model.eval()

    logger.info(f"Model config: {config}")
    logger.info(f"Best val_loss from training: {checkpoint.get('best_val_loss', 'N/A')}")

    # Load the evaluation data.
    logger.info(f"Loading validation data from {val_path}")
    val_df = pd.read_parquet(val_path)
    val_dataset = ExternalDeliveryDataset(val_df)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )
    logger.info(f"Validation samples: {len(val_dataset)}")

    # Predict batch by batch, keeping one flat numpy chunk per batch.
    logger.info("Running predictions...")
    pred_chunks = []
    target_chunks = []
    mask_chunks = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Predicting"):
            tabular = {k: v.to(device_obj) for k, v in batch["tabular"].items()}
            pred = model.forward_delivery(batch["smiles"], tabular)

            # reshape(-1) keeps every chunk 1-D, even for a single-sample batch.
            pred_chunks.append(pred.squeeze(-1).cpu().numpy().reshape(-1))
            target_chunks.append(batch["targets"]["delivery"].numpy().reshape(-1))
            mask_chunks.append(batch["mask"]["delivery"].numpy().reshape(-1))

    if not pred_chunks:
        raise ValueError(f"No validation batches were produced from {val_path}")

    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)
    all_masks = np.concatenate(mask_chunks).astype(bool)

    # Score only the samples whose delivery target is marked valid.
    y_pred = all_preds[all_masks]
    y_true = all_targets[all_masks]
    n_samples = int(all_masks.sum())
    if n_samples == 0:
        raise ValueError(
            f"No samples with a valid delivery target in {val_path}; "
            "cannot compute metrics"
        )

    # Metrics.
    mse = float(mean_squared_error(y_true, y_pred))
    rmse = float(np.sqrt(mse))
    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))

    # Summary statistics, computed once and reused for logging and output.
    true_mean, true_std = float(y_true.mean()), float(y_true.std())
    pred_mean, pred_std = float(y_pred.mean()), float(y_pred.std())

    # The correlation coefficient is undefined for a single sample or a
    # constant series; report NaN without triggering numpy's warning.
    if n_samples > 1 and true_std > 0 and pred_std > 0:
        correlation = float(np.corrcoef(y_true, y_pred)[0, 1])
    else:
        correlation = float("nan")

    metrics = {
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
        "correlation": correlation,
    }
    statistics = {
        "y_true_mean": true_mean,
        "y_true_std": true_std,
        "y_pred_mean": pred_mean,
        "y_pred_std": pred_std,
    }
    results = {
        "model_path": str(model_path),
        "val_path": str(val_path),
        "n_samples": n_samples,
        # Non-finite values become null so the file stays valid JSON.
        "metrics": {k: _finite_or_none(v) for k, v in metrics.items()},
        "statistics": {k: _finite_or_none(v) for k, v in statistics.items()},
    }

    # Report.
    logger.info("\n" + "=" * 50)
    logger.info("PRETRAIN MODEL EVALUATION (Delivery)")
    logger.info("=" * 50)
    logger.info(f"Samples: {results['n_samples']}")
    logger.info("\n[Metrics]")
    logger.info(f" MSE: {mse:.4f}")
    logger.info(f" RMSE: {rmse:.4f}")
    logger.info(f" MAE: {mae:.4f}")
    logger.info(f" R²: {r2:.4f}")
    logger.info(f" Correlation: {correlation:.4f}")
    logger.info("\n[Statistics]")
    logger.info(f" True: mean={true_mean:.4f}, std={true_std:.4f}")
    logger.info(f" Pred: mean={pred_mean:.4f}, std={pred_std:.4f}")

    # Persist the results, creating the output directory if needed.
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, allow_nan=False)
    logger.success(f"\nSaved results to {output_path}")


if __name__ == "__main__":
    app()