This commit is contained in:
RYDE-WORK 2026-01-22 00:24:13 +08:00
parent e123fc8f3e
commit e6a5e5495a
21 changed files with 108 additions and 212 deletions

View File

@ -74,9 +74,9 @@ data_pretrain: requirements
$(PYTHON_INTERPRETER) scripts/process_external.py $(PYTHON_INTERPRETER) scripts/process_external.py
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv) ## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
.PHONY: data_cv .PHONY: data_pretrain_cv
data_cv: requirements data_pretrain_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_data_cv.py $(PYTHON_INTERPRETER) scripts/process_external_cv.py
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder # MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
# 例如make pretrain USE_MPNN=1 # 例如make pretrain USE_MPNN=1
@ -106,8 +106,8 @@ pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG) $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint) ## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv .PHONY: test_pretrain_cv
test_cv: requirements test_pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG) $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
## Train model (multi-task, from scratch) ## Train model (multi-task, from scratch)

View File

@ -271,7 +271,7 @@ def create_model(
@app.command() @app.command()
def main( def main(
data_dir: Path = PROCESSED_DATA_DIR / "cv", data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
output_dir: Path = MODELS_DIR / "pretrain_cv", output_dir: Path = MODELS_DIR / "pretrain_cv",
# 模型参数 # 模型参数
d_model: int = 256, d_model: int = 256,
@ -322,7 +322,7 @@ def main(
if not fold_dirs: if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}") logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_cv' first to process CV data.") logger.info("Please run 'make data_pretrain_cv' first to process CV data.")
raise typer.Exit(1) raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}") logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
@ -464,7 +464,7 @@ def main(
@app.command() @app.command()
def test( def test(
data_dir: Path = PROCESSED_DATA_DIR / "cv", data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
model_dir: Path = MODELS_DIR / "pretrain_cv", model_dir: Path = MODELS_DIR / "pretrain_cv",
output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json", output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
batch_size: int = 64, batch_size: int = 64,

Binary file not shown.

View File

@ -1,310 +1,206 @@
{ {
"train": [ "train": [
{ {
"loss": 0.8244676398801744, "loss": 0.7730368412685099,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.6991508170533461, "loss": 0.658895703010919,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.6388374940987616, "loss": 0.6059015260392299,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.6008581508669937, "loss": 0.5744731174349416,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.584832567446085, "loss": 0.5452056020458733,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.5481657371815157, "loss": 0.5138543470936083,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.5368926340308079, "loss": 0.4885380559178135,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.5210388793613561, "loss": 0.47587182296687974,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.49758357966374045, "loss": 0.4671051038255316,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.49256294099457043, "loss": 0.46794115915756107,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4697267088016886, "loss": 0.4293930456997915,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.45763822707571084, "loss": 0.42624105651716415,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4495221330627172, "loss": 0.4131358770446828,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.446159594079631, "loss": 0.3946074267790835,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4327090857889029, "loss": 0.3898155013755344,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4249273364101852, "loss": 0.37861797005733383,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4216959138704459, "loss": 0.3775682858392304,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.416526201182502, "loss": 0.3800349080262064,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.40368679039741573, "loss": 0.36302345173031675,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.4051084730032182, "loss": 0.3429561740842766,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.38971701020385785, "loss": 0.3445638883004898,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.39155546386038786, "loss": 0.318970229203733,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.37976963541784114, "loss": 0.30179278279904437,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.36484339719805037, "loss": 0.2887343142006437,
"n_samples": 8721 "n_samples": 6783
}, },
{ {
"loss": 0.36232607571196496, "loss": 0.29240367556855545,
"n_samples": 8721 "n_samples": 6783
},
{
"loss": 0.3345973272380199,
"n_samples": 8721
},
{
"loss": 0.31767916518768957,
"n_samples": 8721
},
{
"loss": 0.32065429246052457,
"n_samples": 8721
},
{
"loss": 0.3171297926146043,
"n_samples": 8721
},
{
"loss": 0.3122120894173009,
"n_samples": 8721
},
{
"loss": 0.3135035038404461,
"n_samples": 8721
},
{
"loss": 0.2987745178222875,
"n_samples": 8721
},
{
"loss": 0.2914867957853393,
"n_samples": 8721
},
{
"loss": 0.2983839795507705,
"n_samples": 8721
},
{
"loss": 0.2826709597875678,
"n_samples": 8721
},
{
"loss": 0.2731766632569382,
"n_samples": 8721
},
{
"loss": 0.27726896305742266,
"n_samples": 8721
},
{
"loss": 0.27864557847067956,
"n_samples": 8721
} }
], ],
"val": [ "val": [
{ {
"loss": 0.7601077516012517, "loss": 0.7350345371841441,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.7119935319611901, "loss": 0.7165568811318536,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.6461842978148269, "loss": 0.7251406249862214,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.7006978391063226, "loss": 0.6836505264587159,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.6533874032943979, "loss": 0.6747132955771933,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.6413641451743611, "loss": 0.6691136244936912,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.6168395132979742, "loss": 0.6337480902323249,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.6095251602162025, "loss": 0.6600317959527934,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5887809592626905, "loss": 0.6439923948855346,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5655298325376368, "loss": 0.643800035575267,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5809201743872788, "loss": 0.6181512585221839,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5897585974912033, "loss": 0.6442458634939151,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5732012489662573, "loss": 0.6344759362359862,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5607388911786094, "loss": 0.6501405371457472,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5717580675371414, "loss": 0.6098835162990152,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5553950037657291, "loss": 0.6366627322138894,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5778171792857049, "loss": 0.6171610150646417,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5602665468127734, "loss": 0.6358801012273748,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5475307451359259, "loss": 0.6239976831059871,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.551515599314827, "loss": 0.6683828232827201,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5755438121541243, "loss": 0.6655785786478143,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5798238261811381, "loss": 0.6152775046503088,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5739961433828923, "loss": 0.6202247662153858,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5742599932540312, "loss": 0.648199727435189,
"n_samples": 969 "n_samples": 2907
}, },
{ {
"loss": 0.5834948123885382, "loss": 0.6473217075085124,
"n_samples": 969 "n_samples": 2907
},
{
"loss": 0.554078846570139,
"n_samples": 969
},
{
"loss": 0.5714933996322354,
"n_samples": 969
},
{
"loss": 0.5384107524350331,
"n_samples": 969
},
{
"loss": 0.570854394451568,
"n_samples": 969
},
{
"loss": 0.5767292551642478,
"n_samples": 969
},
{
"loss": 0.5660079547556808,
"n_samples": 969
},
{
"loss": 0.5608972411514312,
"n_samples": 969
},
{
"loss": 0.5620947442987263,
"n_samples": 969
},
{
"loss": 0.5706970894361305,
"n_samples": 969
},
{
"loss": 0.5702376298690974,
"n_samples": 969
},
{
"loss": 0.5758474825259579,
"n_samples": 969
},
{
"loss": 0.5673816067284844,
"n_samples": 969
},
{
"loss": 0.5671441179879925,
"n_samples": 969
} }
] ]
} }