From c392b4899457fcedd91f7fff87d575183a9462fc Mon Sep 17 00:00:00 2001 From: RYDE-WORK Date: Wed, 21 Jan 2026 22:57:44 +0800 Subject: [PATCH] Add CV results --- Makefile | 26 ++++---- .../evaluation_results.json | 2 +- data/processed/test_results.json | 42 ++++++------- lnp_ml/modeling/predict.py | 2 +- lnp_ml/modeling/pretrain_cv.py | 11 ++-- models/pretrain_cv/test_results.json | 63 +++++++++++++++++++ 6 files changed, 108 insertions(+), 38 deletions(-) create mode 100644 models/pretrain_cv/test_results.json diff --git a/Makefile b/Makefile index b064847..b3c6b13 100644 --- a/Makefile +++ b/Makefile @@ -86,50 +86,54 @@ MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,) # 例如:make finetune FREEZE_BACKBONE=1 FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,) +# 设备选择:使用 DEVICE=xxx 指定设备 +# 例如:make train DEVICE=cuda:0 或 make test_cv DEVICE=mps +DEVICE_FLAG = $(if $(DEVICE),--device $(DEVICE),) + ## Pretrain on external data (delivery only) .PHONY: pretrain pretrain: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG) $(DEVICE_FLAG) ## Evaluate pretrain model (delivery metrics) .PHONY: test_pretrain test_pretrain: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) $(DEVICE_FLAG) ## Pretrain with cross-validation (5-fold) .PHONY: pretrain_cv pretrain_cv: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG) -## Evaluate CV pretrain models on test sets +## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint) .PHONY: test_cv test_cv: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG) ## Train model (multi-task, from scratch) .PHONY: train train: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG) $(DEVICE_FLAG) ## Finetune from pretrained checkpoint (use FREEZE_BACKBONE=1 to freeze backbone) .PHONY: finetune finetune: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG) ## Train with hyperparameter tuning .PHONY: tune tune: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG) ## Run predictions .PHONY: predict predict: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict $(DEVICE_FLAG) -## Test model on test set (with detailed metrics) +## Test model on test set (with detailed metrics, auto-detects MPNN from checkpoint) .PHONY: test test: requirements - $(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test $(DEVICE_FLAG) ################################################################################# diff --git a/data/external/all_amine_split_for_LiON/evaluation_results.json b/data/external/all_amine_split_for_LiON/evaluation_results.json index b038ed7..c94ff52 100644 --- a/data/external/all_amine_split_for_LiON/evaluation_results.json +++ b/data/external/all_amine_split_for_LiON/evaluation_results.json @@ -1,5 +1,5 @@ { - "data_dir": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/external/all_amine_split_for_LiON", + "data_dir": "/Users/ryde/Code/lnp_ml/data/external/all_amine_split_for_LiON", "y_col": "quantified_delivery", "n_splits": 5, "split_results": { diff --git a/data/processed/test_results.json b/data/processed/test_results.json index 5e4970d..5b23fc1 100644 --- a/data/processed/test_results.json +++ b/data/processed/test_results.json @@ -1,37 +1,37 @@ { "loss_metrics": { - "loss": 2.5374555587768555, - "loss_size": 0.1886825958887736, - "loss_pdi": 0.45798932512601215, - "loss_ee": 0.829658567905426, - "loss_delivery": 0.4857304096221924, - "loss_biodist": 0.5346279243628184, - "loss_toxic": 0.04076674363265435, - "acc_pdi": 0.7862595419847328, - "acc_ee": 0.6793893129770993, - "acc_toxic": 0.9801980198019802 + "loss": 2.8661977450052896, + "loss_size": 0.44916408757368725, + "loss_pdi": 0.5041926403840383, + "loss_ee": 0.9021427234013876, + "loss_delivery": 0.5761533578236898, + "loss_biodist": 0.4019051690896352, + "loss_toxic": 0.03263980595511384, + "acc_pdi": 0.7633587786259542, + "acc_ee": 0.6641221374045801, + "acc_toxic": 0.9702970297029703 }, "detailed_metrics": { "size": { - "mse": 0.1669999969286325, - "rmse": 0.4086563310761654, - "mae": 0.26111859684375066, - "r2": 0.2149270281561566 + "mse": 0.41126506251447736, + "rmse": 0.6412995107704959, + "mae": 0.41415552388095633, + "r2": -0.9333718010891026 }, "delivery": { - "mse": 0.5193460523366603, - "rmse": 0.7206566813238189, - "mae": 0.4828052782115008, - "r2": 0.37299826459145 + "mse": 0.6277965050686476, + "rmse": 0.7923361061245711, + "mae": 0.5387302115022443, + "r2": 0.24206702565575944 }, "pdi": { - "accuracy": 0.7862595419847328 + "accuracy": 0.7633587786259542 }, "ee": { - "accuracy": 0.6793893129770993 + "accuracy": 0.6641221374045801 }, "toxic": { - "accuracy": 0.9801980198019802 + "accuracy": 0.9702970297029703 } } } \ No newline at end of file diff --git a/lnp_ml/modeling/predict.py b/lnp_ml/modeling/predict.py index 2746dc5..1fe6b43 100644 --- a/lnp_ml/modeling/predict.py +++ b/lnp_ml/modeling/predict.py @@ -43,7 +43,7 @@ def load_model( use_mpnn = config.get("use_mpnn", False) if use_mpnn: - # 自动查找 MPNN ensemble + # 总是自动查找 MPNN ensemble,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器) logger.info("Model was trained with MPNN, auto-detecting ensemble...") ensemble_paths = find_mpnn_ensemble_paths() logger.info(f"Found {len(ensemble_paths)} MPNN models") diff --git a/lnp_ml/modeling/pretrain_cv.py b/lnp_ml/modeling/pretrain_cv.py index 357382d..33bd9b8 100644 --- a/lnp_ml/modeling/pretrain_cv.py +++ b/lnp_ml/modeling/pretrain_cv.py @@ -491,7 +491,7 @@ def test( all_preds = [] all_targets = [] - for fold_dir in fold_dirs: + for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"): fold_idx = int(fold_dir.name.split("_")[1]) model_path = model_dir / f"fold_{fold_idx}" / "model.pt" test_path = fold_dir / "test.parquet" @@ -509,10 +509,12 @@ def test( config = checkpoint["config"] use_mpnn = config.get("use_mpnn", False) - mpnn_paths = config.get("mpnn_ensemble_paths") - if use_mpnn and not mpnn_paths: + # 总是重新查找 MPNN 路径,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器) + if use_mpnn: mpnn_paths = find_mpnn_ensemble_paths() + else: + mpnn_paths = None model = create_model( d_model=config["d_model"], @@ -541,7 +543,8 @@ def test( fold_targets = [] with torch.no_grad(): - for batch in test_loader: + pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False) + for batch in pbar: smiles = batch["smiles"] tabular = {k: v.to(device) for k, v in batch["tabular"].items()} targets = batch["targets"]["delivery"].to(device) diff --git a/models/pretrain_cv/test_results.json b/models/pretrain_cv/test_results.json new file mode 100644 index 0000000..6d56694 --- /dev/null +++ b/models/pretrain_cv/test_results.json @@ -0,0 +1,63 @@ +{ + "fold_results": [ + { + "fold_idx": 0, + "n_samples": 2037, + "mse": 0.704966583529581, + "rmse": 0.8396228817329724, + "mae": 0.6275912649623819, + "r2": 0.18947693925460807, + "correlation": 0.4894546998045728 + }, + { + "fold_idx": 1, + "n_samples": 1658, + "mse": 0.9115043728317782, + "rmse": 0.9547273814193129, + "mae": 0.7232097773077567, + "r2": 0.21745691549383905, + "correlation": 0.48734383353129934 + }, + { + "fold_idx": 2, + "n_samples": 1615, + "mse": 0.8047081385461913, + "rmse": 0.8970552594718963, + "mae": 0.6672816107495062, + "r2": 0.18280256744724777, + "correlation": 0.506874931057286 + }, + { + "fold_idx": 3, + "n_samples": 1754, + "mse": 0.8592105232097863, + "rmse": 0.926936094458397, + "mae": 0.6835461141701688, + "r2": 0.1746865688538688, + "correlation": 0.4233578915460334 + }, + { + "fold_idx": 4, + "n_samples": 1520, + "mse": 0.6832138393192553, + "rmse": 0.8265675043934738, + "mae": 0.6324079235422495, + "r2": 0.2766966748218629, + "correlation": 0.5350317708563872 + } + ], + "summary_stats": { + "rmse_mean": 0.8889818242952104, + "rmse_std": 0.04931538867410365, + "r2_mean": 0.2082239331742853, + "r2_std": 0.03713816024862206 + }, + "overall": { + "n_samples": 8584, + "mse": 0.7912902048033758, + "rmse": 0.8895449425427452, + "mae": 0.6658137170204776, + "r2": 0.20870978896035208, + "correlation": 0.4812266887089009 + } +} \ No newline at end of file