Add CV results

This commit is contained in:
RYDE-WORK 2026-01-21 22:57:44 +08:00
parent e1c85c83ba
commit c392b48994
6 changed files with 108 additions and 38 deletions

View File

@@ -86,50 +86,54 @@ MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
# 例如make finetune FREEZE_BACKBONE=1
FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)
# 设备选择:使用 DEVICE=xxx 指定设备
# 例如make train DEVICE=cuda:0 或 make test_cv DEVICE=mps
DEVICE_FLAG = $(if $(DEVICE),--device $(DEVICE),)
## Pretrain on external data (delivery only)
.PHONY: pretrain
pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate pretrain model (delivery metrics)
.PHONY: test_pretrain
test_pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) $(DEVICE_FLAG)
## Pretrain with cross-validation (5-fold)
.PHONY: pretrain_cv
pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV pretrain models on test sets
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
## Train model (multi-task, from scratch)
.PHONY: train
train: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG) $(DEVICE_FLAG)
## Finetune from pretrained checkpoint (use FREEZE_BACKBONE=1 to freeze backbone)
.PHONY: finetune
finetune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Train with hyperparameter tuning
.PHONY: tune
tune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG)
## Run predictions
.PHONY: predict
predict: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict $(DEVICE_FLAG)
## Test model on test set (with detailed metrics)
## Test model on test set (with detailed metrics, auto-detects MPNN from checkpoint)
.PHONY: test
test: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test $(DEVICE_FLAG)
#################################################################################

View File

@@ -1,5 +1,5 @@
{
"data_dir": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/external/all_amine_split_for_LiON",
"data_dir": "/Users/ryde/Code/lnp_ml/data/external/all_amine_split_for_LiON",
"y_col": "quantified_delivery",
"n_splits": 5,
"split_results": {

View File

@@ -1,37 +1,37 @@
{
"loss_metrics": {
"loss": 2.5374555587768555,
"loss_size": 0.1886825958887736,
"loss_pdi": 0.45798932512601215,
"loss_ee": 0.829658567905426,
"loss_delivery": 0.4857304096221924,
"loss_biodist": 0.5346279243628184,
"loss_toxic": 0.04076674363265435,
"acc_pdi": 0.7862595419847328,
"acc_ee": 0.6793893129770993,
"acc_toxic": 0.9801980198019802
"loss": 2.8661977450052896,
"loss_size": 0.44916408757368725,
"loss_pdi": 0.5041926403840383,
"loss_ee": 0.9021427234013876,
"loss_delivery": 0.5761533578236898,
"loss_biodist": 0.4019051690896352,
"loss_toxic": 0.03263980595511384,
"acc_pdi": 0.7633587786259542,
"acc_ee": 0.6641221374045801,
"acc_toxic": 0.9702970297029703
},
"detailed_metrics": {
"size": {
"mse": 0.1669999969286325,
"rmse": 0.4086563310761654,
"mae": 0.26111859684375066,
"r2": 0.2149270281561566
"mse": 0.41126506251447736,
"rmse": 0.6412995107704959,
"mae": 0.41415552388095633,
"r2": -0.9333718010891026
},
"delivery": {
"mse": 0.5193460523366603,
"rmse": 0.7206566813238189,
"mae": 0.4828052782115008,
"r2": 0.37299826459145
"mse": 0.6277965050686476,
"rmse": 0.7923361061245711,
"mae": 0.5387302115022443,
"r2": 0.24206702565575944
},
"pdi": {
"accuracy": 0.7862595419847328
"accuracy": 0.7633587786259542
},
"ee": {
"accuracy": 0.6793893129770993
"accuracy": 0.6641221374045801
},
"toxic": {
"accuracy": 0.9801980198019802
"accuracy": 0.9702970297029703
}
}
}

View File

@@ -43,7 +43,7 @@ def load_model(
use_mpnn = config.get("use_mpnn", False)
if use_mpnn:
# 自动查找 MPNN ensemble
# 总是自动查找 MPNN ensemble,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器)
logger.info("Model was trained with MPNN, auto-detecting ensemble...")
ensemble_paths = find_mpnn_ensemble_paths()
logger.info(f"Found {len(ensemble_paths)} MPNN models")

View File

@@ -491,7 +491,7 @@ def test(
all_preds = []
all_targets = []
for fold_dir in fold_dirs:
for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
fold_idx = int(fold_dir.name.split("_")[1])
model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
test_path = fold_dir / "test.parquet"
@@ -509,10 +509,12 @@
config = checkpoint["config"]
use_mpnn = config.get("use_mpnn", False)
mpnn_paths = config.get("mpnn_ensemble_paths")
if use_mpnn and not mpnn_paths:
# 总是重新查找 MPNN 路径,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器)
if use_mpnn:
mpnn_paths = find_mpnn_ensemble_paths()
else:
mpnn_paths = None
model = create_model(
d_model=config["d_model"],
@@ -541,7 +543,8 @@
fold_targets = []
with torch.no_grad():
for batch in test_loader:
pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
for batch in pbar:
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
targets = batch["targets"]["delivery"].to(device)

View File

@@ -0,0 +1,63 @@
{
"fold_results": [
{
"fold_idx": 0,
"n_samples": 2037,
"mse": 0.704966583529581,
"rmse": 0.8396228817329724,
"mae": 0.6275912649623819,
"r2": 0.18947693925460807,
"correlation": 0.4894546998045728
},
{
"fold_idx": 1,
"n_samples": 1658,
"mse": 0.9115043728317782,
"rmse": 0.9547273814193129,
"mae": 0.7232097773077567,
"r2": 0.21745691549383905,
"correlation": 0.48734383353129934
},
{
"fold_idx": 2,
"n_samples": 1615,
"mse": 0.8047081385461913,
"rmse": 0.8970552594718963,
"mae": 0.6672816107495062,
"r2": 0.18280256744724777,
"correlation": 0.506874931057286
},
{
"fold_idx": 3,
"n_samples": 1754,
"mse": 0.8592105232097863,
"rmse": 0.926936094458397,
"mae": 0.6835461141701688,
"r2": 0.1746865688538688,
"correlation": 0.4233578915460334
},
{
"fold_idx": 4,
"n_samples": 1520,
"mse": 0.6832138393192553,
"rmse": 0.8265675043934738,
"mae": 0.6324079235422495,
"r2": 0.2766966748218629,
"correlation": 0.5350317708563872
}
],
"summary_stats": {
"rmse_mean": 0.8889818242952104,
"rmse_std": 0.04931538867410365,
"r2_mean": 0.2082239331742853,
"r2_std": 0.03713816024862206
},
"overall": {
"n_samples": 8584,
"mse": 0.7912902048033758,
"rmse": 0.8895449425427452,
"mae": 0.6658137170204776,
"r2": 0.20870978896035208,
"correlation": 0.4812266887089009
}
}