mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
Add CV results
This commit is contained in:
parent
e1c85c83ba
commit
c392b48994
26
Makefile
26
Makefile
@ -86,50 +86,54 @@ MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
|
||||
# 例如:make finetune FREEZE_BACKBONE=1
|
||||
FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)
|
||||
|
||||
# 设备选择:使用 DEVICE=xxx 指定设备
|
||||
# 例如:make train DEVICE=cuda:0 或 make test_cv DEVICE=mps
|
||||
DEVICE_FLAG = $(if $(DEVICE),--device $(DEVICE),)
|
||||
|
||||
## Pretrain on external data (delivery only)
|
||||
.PHONY: pretrain
|
||||
pretrain: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Evaluate pretrain model (delivery metrics)
|
||||
.PHONY: test_pretrain
|
||||
test_pretrain: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Pretrain with cross-validation (5-fold)
|
||||
.PHONY: pretrain_cv
|
||||
pretrain_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Evaluate CV pretrain models on test sets
|
||||
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
|
||||
.PHONY: test_cv
|
||||
test_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
|
||||
|
||||
## Train model (multi-task, from scratch)
|
||||
.PHONY: train
|
||||
train: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Finetune from pretrained checkpoint (use FREEZE_BACKBONE=1 to freeze backbone)
|
||||
.PHONY: finetune
|
||||
finetune: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Train with hyperparameter tuning
|
||||
.PHONY: tune
|
||||
tune: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG)
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Run predictions
|
||||
.PHONY: predict
|
||||
predict: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict $(DEVICE_FLAG)
|
||||
|
||||
## Test model on test set (with detailed metrics)
|
||||
## Test model on test set (with detailed metrics, auto-detects MPNN from checkpoint)
|
||||
.PHONY: test
|
||||
test: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test $(DEVICE_FLAG)
|
||||
|
||||
|
||||
#################################################################################
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
{
|
||||
"data_dir": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/external/all_amine_split_for_LiON",
|
||||
"data_dir": "/Users/ryde/Code/lnp_ml/data/external/all_amine_split_for_LiON",
|
||||
"y_col": "quantified_delivery",
|
||||
"n_splits": 5,
|
||||
"split_results": {
|
||||
|
||||
@ -1,37 +1,37 @@
|
||||
{
|
||||
"loss_metrics": {
|
||||
"loss": 2.5374555587768555,
|
||||
"loss_size": 0.1886825958887736,
|
||||
"loss_pdi": 0.45798932512601215,
|
||||
"loss_ee": 0.829658567905426,
|
||||
"loss_delivery": 0.4857304096221924,
|
||||
"loss_biodist": 0.5346279243628184,
|
||||
"loss_toxic": 0.04076674363265435,
|
||||
"acc_pdi": 0.7862595419847328,
|
||||
"acc_ee": 0.6793893129770993,
|
||||
"acc_toxic": 0.9801980198019802
|
||||
"loss": 2.8661977450052896,
|
||||
"loss_size": 0.44916408757368725,
|
||||
"loss_pdi": 0.5041926403840383,
|
||||
"loss_ee": 0.9021427234013876,
|
||||
"loss_delivery": 0.5761533578236898,
|
||||
"loss_biodist": 0.4019051690896352,
|
||||
"loss_toxic": 0.03263980595511384,
|
||||
"acc_pdi": 0.7633587786259542,
|
||||
"acc_ee": 0.6641221374045801,
|
||||
"acc_toxic": 0.9702970297029703
|
||||
},
|
||||
"detailed_metrics": {
|
||||
"size": {
|
||||
"mse": 0.1669999969286325,
|
||||
"rmse": 0.4086563310761654,
|
||||
"mae": 0.26111859684375066,
|
||||
"r2": 0.2149270281561566
|
||||
"mse": 0.41126506251447736,
|
||||
"rmse": 0.6412995107704959,
|
||||
"mae": 0.41415552388095633,
|
||||
"r2": -0.9333718010891026
|
||||
},
|
||||
"delivery": {
|
||||
"mse": 0.5193460523366603,
|
||||
"rmse": 0.7206566813238189,
|
||||
"mae": 0.4828052782115008,
|
||||
"r2": 0.37299826459145
|
||||
"mse": 0.6277965050686476,
|
||||
"rmse": 0.7923361061245711,
|
||||
"mae": 0.5387302115022443,
|
||||
"r2": 0.24206702565575944
|
||||
},
|
||||
"pdi": {
|
||||
"accuracy": 0.7862595419847328
|
||||
"accuracy": 0.7633587786259542
|
||||
},
|
||||
"ee": {
|
||||
"accuracy": 0.6793893129770993
|
||||
"accuracy": 0.6641221374045801
|
||||
},
|
||||
"toxic": {
|
||||
"accuracy": 0.9801980198019802
|
||||
"accuracy": 0.9702970297029703
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -43,7 +43,7 @@ def load_model(
|
||||
use_mpnn = config.get("use_mpnn", False)
|
||||
|
||||
if use_mpnn:
|
||||
# 自动查找 MPNN ensemble
|
||||
# 总是自动查找 MPNN ensemble,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器)
|
||||
logger.info("Model was trained with MPNN, auto-detecting ensemble...")
|
||||
ensemble_paths = find_mpnn_ensemble_paths()
|
||||
logger.info(f"Found {len(ensemble_paths)} MPNN models")
|
||||
|
||||
@ -491,7 +491,7 @@ def test(
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
for fold_dir in fold_dirs:
|
||||
for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
|
||||
fold_idx = int(fold_dir.name.split("_")[1])
|
||||
model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
|
||||
test_path = fold_dir / "test.parquet"
|
||||
@ -509,10 +509,12 @@ def test(
|
||||
config = checkpoint["config"]
|
||||
|
||||
use_mpnn = config.get("use_mpnn", False)
|
||||
mpnn_paths = config.get("mpnn_ensemble_paths")
|
||||
|
||||
if use_mpnn and not mpnn_paths:
|
||||
# 总是重新查找 MPNN 路径,避免使用 checkpoint 中的旧绝对路径(可能来自其他机器)
|
||||
if use_mpnn:
|
||||
mpnn_paths = find_mpnn_ensemble_paths()
|
||||
else:
|
||||
mpnn_paths = None
|
||||
|
||||
model = create_model(
|
||||
d_model=config["d_model"],
|
||||
@ -541,7 +543,8 @@ def test(
|
||||
fold_targets = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in test_loader:
|
||||
pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
|
||||
for batch in pbar:
|
||||
smiles = batch["smiles"]
|
||||
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||
targets = batch["targets"]["delivery"].to(device)
|
||||
|
||||
63
models/pretrain_cv/test_results.json
Normal file
63
models/pretrain_cv/test_results.json
Normal file
@ -0,0 +1,63 @@
|
||||
{
|
||||
"fold_results": [
|
||||
{
|
||||
"fold_idx": 0,
|
||||
"n_samples": 2037,
|
||||
"mse": 0.704966583529581,
|
||||
"rmse": 0.8396228817329724,
|
||||
"mae": 0.6275912649623819,
|
||||
"r2": 0.18947693925460807,
|
||||
"correlation": 0.4894546998045728
|
||||
},
|
||||
{
|
||||
"fold_idx": 1,
|
||||
"n_samples": 1658,
|
||||
"mse": 0.9115043728317782,
|
||||
"rmse": 0.9547273814193129,
|
||||
"mae": 0.7232097773077567,
|
||||
"r2": 0.21745691549383905,
|
||||
"correlation": 0.48734383353129934
|
||||
},
|
||||
{
|
||||
"fold_idx": 2,
|
||||
"n_samples": 1615,
|
||||
"mse": 0.8047081385461913,
|
||||
"rmse": 0.8970552594718963,
|
||||
"mae": 0.6672816107495062,
|
||||
"r2": 0.18280256744724777,
|
||||
"correlation": 0.506874931057286
|
||||
},
|
||||
{
|
||||
"fold_idx": 3,
|
||||
"n_samples": 1754,
|
||||
"mse": 0.8592105232097863,
|
||||
"rmse": 0.926936094458397,
|
||||
"mae": 0.6835461141701688,
|
||||
"r2": 0.1746865688538688,
|
||||
"correlation": 0.4233578915460334
|
||||
},
|
||||
{
|
||||
"fold_idx": 4,
|
||||
"n_samples": 1520,
|
||||
"mse": 0.6832138393192553,
|
||||
"rmse": 0.8265675043934738,
|
||||
"mae": 0.6324079235422495,
|
||||
"r2": 0.2766966748218629,
|
||||
"correlation": 0.5350317708563872
|
||||
}
|
||||
],
|
||||
"summary_stats": {
|
||||
"rmse_mean": 0.8889818242952104,
|
||||
"rmse_std": 0.04931538867410365,
|
||||
"r2_mean": 0.2082239331742853,
|
||||
"r2_std": 0.03713816024862206
|
||||
},
|
||||
"overall": {
|
||||
"n_samples": 8584,
|
||||
"mse": 0.7912902048033758,
|
||||
"rmse": 0.8895449425427452,
|
||||
"mae": 0.6658137170204776,
|
||||
"r2": 0.20870978896035208,
|
||||
"correlation": 0.4812266887089009
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user