diff --git a/Makefile b/Makefile index b3c6b13..e73b512 100644 --- a/Makefile +++ b/Makefile @@ -74,9 +74,9 @@ data_pretrain: requirements $(PYTHON_INTERPRETER) scripts/process_external.py ## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv) -.PHONY: data_cv -data_cv: requirements - $(PYTHON_INTERPRETER) scripts/process_data_cv.py +.PHONY: data_pretrain_cv +data_pretrain_cv: requirements + $(PYTHON_INTERPRETER) scripts/process_external_cv.py # MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder # 例如:make pretrain USE_MPNN=1 @@ -106,8 +106,8 @@ pretrain_cv: requirements $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG) ## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint) -.PHONY: test_cv -test_cv: requirements +.PHONY: test_pretrain_cv +test_pretrain_cv: requirements $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG) ## Train model (multi-task, from scratch) diff --git a/data/processed/cv/feature_columns.txt b/data/processed/pretrain_cv/feature_columns.txt similarity index 100% rename from data/processed/cv/feature_columns.txt rename to data/processed/pretrain_cv/feature_columns.txt diff --git a/data/processed/cv/fold_0/test.parquet b/data/processed/pretrain_cv/fold_0/test.parquet similarity index 100% rename from data/processed/cv/fold_0/test.parquet rename to data/processed/pretrain_cv/fold_0/test.parquet diff --git a/data/processed/cv/fold_0/train.parquet b/data/processed/pretrain_cv/fold_0/train.parquet similarity index 100% rename from data/processed/cv/fold_0/train.parquet rename to data/processed/pretrain_cv/fold_0/train.parquet diff --git a/data/processed/cv/fold_0/valid.parquet b/data/processed/pretrain_cv/fold_0/valid.parquet similarity index 100% rename from data/processed/cv/fold_0/valid.parquet rename to data/processed/pretrain_cv/fold_0/valid.parquet diff --git a/data/processed/cv/fold_1/test.parquet b/data/processed/pretrain_cv/fold_1/test.parquet similarity index 100% rename from data/processed/cv/fold_1/test.parquet rename to data/processed/pretrain_cv/fold_1/test.parquet diff --git a/data/processed/cv/fold_1/train.parquet b/data/processed/pretrain_cv/fold_1/train.parquet similarity index 100% rename from data/processed/cv/fold_1/train.parquet rename to data/processed/pretrain_cv/fold_1/train.parquet diff --git a/data/processed/cv/fold_1/valid.parquet b/data/processed/pretrain_cv/fold_1/valid.parquet similarity index 100% rename from data/processed/cv/fold_1/valid.parquet rename to data/processed/pretrain_cv/fold_1/valid.parquet diff --git a/data/processed/cv/fold_2/test.parquet b/data/processed/pretrain_cv/fold_2/test.parquet similarity index 100% rename from data/processed/cv/fold_2/test.parquet rename to data/processed/pretrain_cv/fold_2/test.parquet diff --git a/data/processed/cv/fold_2/train.parquet b/data/processed/pretrain_cv/fold_2/train.parquet similarity index 100% rename from data/processed/cv/fold_2/train.parquet rename to data/processed/pretrain_cv/fold_2/train.parquet diff --git a/data/processed/cv/fold_2/valid.parquet b/data/processed/pretrain_cv/fold_2/valid.parquet similarity index 100% rename from data/processed/cv/fold_2/valid.parquet rename to data/processed/pretrain_cv/fold_2/valid.parquet diff --git a/data/processed/cv/fold_3/test.parquet b/data/processed/pretrain_cv/fold_3/test.parquet similarity index 100% rename from data/processed/cv/fold_3/test.parquet rename to data/processed/pretrain_cv/fold_3/test.parquet diff --git a/data/processed/cv/fold_3/train.parquet b/data/processed/pretrain_cv/fold_3/train.parquet similarity index 100% rename from data/processed/cv/fold_3/train.parquet rename to data/processed/pretrain_cv/fold_3/train.parquet diff --git a/data/processed/cv/fold_3/valid.parquet b/data/processed/pretrain_cv/fold_3/valid.parquet similarity index 100% rename from data/processed/cv/fold_3/valid.parquet rename to data/processed/pretrain_cv/fold_3/valid.parquet diff --git a/data/processed/cv/fold_4/test.parquet b/data/processed/pretrain_cv/fold_4/test.parquet similarity index 100% rename from data/processed/cv/fold_4/test.parquet rename to data/processed/pretrain_cv/fold_4/test.parquet diff --git a/data/processed/cv/fold_4/train.parquet b/data/processed/pretrain_cv/fold_4/train.parquet similarity index 100% rename from data/processed/cv/fold_4/train.parquet rename to data/processed/pretrain_cv/fold_4/train.parquet diff --git a/data/processed/cv/fold_4/valid.parquet b/data/processed/pretrain_cv/fold_4/valid.parquet similarity index 100% rename from data/processed/cv/fold_4/valid.parquet rename to data/processed/pretrain_cv/fold_4/valid.parquet diff --git a/lnp_ml/modeling/pretrain_cv.py b/lnp_ml/modeling/pretrain_cv.py index 33bd9b8..75188f9 100644 --- a/lnp_ml/modeling/pretrain_cv.py +++ b/lnp_ml/modeling/pretrain_cv.py @@ -271,7 +271,7 @@ def create_model( @app.command() def main( - data_dir: Path = PROCESSED_DATA_DIR / "cv", + data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv", output_dir: Path = MODELS_DIR / "pretrain_cv", # 模型参数 d_model: int = 256, @@ -322,7 +322,7 @@ def main( if not fold_dirs: logger.error(f"No fold_* directories found in {data_dir}") - logger.info("Please run 'make data_cv' first to process CV data.") + logger.info("Please run 'make data_pretrain_cv' first to process CV data.") raise typer.Exit(1) logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}") @@ -464,7 +464,7 @@ def main( @app.command() def test( - data_dir: Path = PROCESSED_DATA_DIR / "cv", + data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv", model_dir: Path = MODELS_DIR / "pretrain_cv", output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json", batch_size: int = 64, diff --git a/models/pretrain_delivery.pt b/models/pretrain_delivery.pt index 3675a40..bc73785 100644 Binary files a/models/pretrain_delivery.pt and b/models/pretrain_delivery.pt differ diff --git a/models/pretrain_history.json b/models/pretrain_history.json index be21bf2..fe5e2a1 100644 --- a/models/pretrain_history.json +++ b/models/pretrain_history.json @@ -1,310 +1,206 @@ { "train": [ { - "loss": 0.8244676398801744, - "n_samples": 8721 + "loss": 0.7730368412685099, + "n_samples": 6783 }, { - "loss": 0.6991508170533461, - "n_samples": 8721 + "loss": 0.658895703010919, + "n_samples": 6783 }, { - "loss": 0.6388374940987616, - "n_samples": 8721 + "loss": 0.6059015260392299, + "n_samples": 6783 }, { - "loss": 0.6008581508669937, - "n_samples": 8721 + "loss": 0.5744731174349416, + "n_samples": 6783 }, { - "loss": 0.584832567446085, - "n_samples": 8721 + "loss": 0.5452056020458733, + "n_samples": 6783 }, { - "loss": 0.5481657371815157, - "n_samples": 8721 + "loss": 0.5138543470936083, + "n_samples": 6783 }, { - "loss": 0.5368926340308079, - "n_samples": 8721 + "loss": 0.4885380559178135, + "n_samples": 6783 }, { - "loss": 0.5210388793613561, - "n_samples": 8721 + "loss": 0.47587182296687974, + "n_samples": 6783 }, { - "loss": 0.49758357966374045, - "n_samples": 8721 + "loss": 0.4671051038255316, + "n_samples": 6783 }, { - "loss": 0.49256294099457043, - "n_samples": 8721 + "loss": 0.46794115915756107, + "n_samples": 6783 }, { - "loss": 0.4697267088016886, - "n_samples": 8721 + "loss": 0.4293930456997915, + "n_samples": 6783 }, { - "loss": 0.45763822707571084, - "n_samples": 8721 + "loss": 0.42624105651716415, + "n_samples": 6783 }, { - "loss": 0.4495221330627172, - "n_samples": 8721 + "loss": 0.4131358770446828, + "n_samples": 6783 }, { - "loss": 0.446159594079631, - "n_samples": 8721 + "loss": 0.3946074267790835, + "n_samples": 6783 }, { - "loss": 0.4327090857889029, - "n_samples": 8721 + "loss": 0.3898155013755344, + "n_samples": 6783 }, { - "loss": 0.4249273364101852, - "n_samples": 8721 + "loss": 0.37861797005733383, + "n_samples": 6783 }, { - "loss": 0.4216959138704459, - "n_samples": 8721 + "loss": 0.3775682858392304, + "n_samples": 6783 }, { - "loss": 0.416526201182502, - "n_samples": 8721 + "loss": 0.3800349080262064, + "n_samples": 6783 }, { - "loss": 0.40368679039741573, - "n_samples": 8721 + "loss": 0.36302345173031675, + "n_samples": 6783 }, { - "loss": 0.4051084730032182, - "n_samples": 8721 + "loss": 0.3429561740842766, + "n_samples": 6783 }, { - "loss": 0.38971701020385785, - "n_samples": 8721 + "loss": 0.3445638883004898, + "n_samples": 6783 }, { - "loss": 0.39155546386038786, - "n_samples": 8721 + "loss": 0.318970229203733, + "n_samples": 6783 }, { - "loss": 0.37976963541784114, - "n_samples": 8721 + "loss": 0.30179278279904437, + "n_samples": 6783 }, { - "loss": 0.36484339719805037, - "n_samples": 8721 + "loss": 0.2887343142006437, + "n_samples": 6783 }, { - "loss": 0.36232607571196496, - "n_samples": 8721 - }, - { - "loss": 0.3345973272380199, - "n_samples": 8721 - }, - { - "loss": 0.31767916518768957, - "n_samples": 8721 - }, - { - "loss": 0.32065429246052457, - "n_samples": 8721 - }, - { - "loss": 0.3171297926146043, - "n_samples": 8721 - }, - { - "loss": 0.3122120894173009, - "n_samples": 8721 - }, - { - "loss": 0.3135035038404461, - "n_samples": 8721 - }, - { - "loss": 0.2987745178222875, - "n_samples": 8721 - }, - { - "loss": 0.2914867957853393, - "n_samples": 8721 - }, - { - "loss": 0.2983839795507705, - "n_samples": 8721 - }, - { - "loss": 0.2826709597875678, - "n_samples": 8721 - }, - { - "loss": 0.2731766632569382, - "n_samples": 8721 - }, - { - "loss": 0.27726896305742266, - "n_samples": 8721 - }, - { - "loss": 0.27864557847067956, - "n_samples": 8721 + "loss": 0.29240367556855545, + "n_samples": 6783 } ], "val": [ { - "loss": 0.7601077516012517, - "n_samples": 969 + "loss": 0.7350345371841441, + "n_samples": 2907 }, { - "loss": 0.7119935319611901, - "n_samples": 969 + "loss": 0.7165568811318536, + "n_samples": 2907 }, { - "loss": 0.6461842978148269, - "n_samples": 969 + "loss": 0.7251406249862214, + "n_samples": 2907 }, { - "loss": 0.7006978391063226, - "n_samples": 969 + "loss": 0.6836505264587159, + "n_samples": 2907 }, { - "loss": 0.6533874032943979, - "n_samples": 969 + "loss": 0.6747132955771933, + "n_samples": 2907 }, { - "loss": 0.6413641451743611, - "n_samples": 969 + "loss": 0.6691136244936912, + "n_samples": 2907 }, { - "loss": 0.6168395132979742, - "n_samples": 969 + "loss": 0.6337480902323249, + "n_samples": 2907 }, { - "loss": 0.6095251602162025, - "n_samples": 969 + "loss": 0.6600317959527934, + "n_samples": 2907 }, { - "loss": 0.5887809592626905, - "n_samples": 969 + "loss": 0.6439923948855346, + "n_samples": 2907 }, { - "loss": 0.5655298325376368, - "n_samples": 969 + "loss": 0.643800035575267, + "n_samples": 2907 }, { - "loss": 0.5809201743872788, - "n_samples": 969 + "loss": 0.6181512585221839, + "n_samples": 2907 }, { - "loss": 0.5897585974912033, - "n_samples": 969 + "loss": 0.6442458634939151, + "n_samples": 2907 }, { - "loss": 0.5732012489662573, - "n_samples": 969 + "loss": 0.6344759362359862, + "n_samples": 2907 }, { - "loss": 0.5607388911786094, - "n_samples": 969 + "loss": 0.6501405371457472, + "n_samples": 2907 }, { - "loss": 0.5717580675371414, - "n_samples": 969 + "loss": 0.6098835162990152, + "n_samples": 2907 }, { - "loss": 0.5553950037657291, - "n_samples": 969 + "loss": 0.6366627322138894, + "n_samples": 2907 }, { - "loss": 0.5778171792857049, - "n_samples": 969 + "loss": 0.6171610150646417, + "n_samples": 2907 }, { - "loss": 0.5602665468127734, - "n_samples": 969 + "loss": 0.6358801012273748, + "n_samples": 2907 }, { - "loss": 0.5475307451359259, - "n_samples": 969 + "loss": 0.6239976831059871, + "n_samples": 2907 }, { - "loss": 0.551515599314827, - "n_samples": 969 + "loss": 0.6683828232827201, + "n_samples": 2907 }, { - "loss": 0.5755438121541243, - "n_samples": 969 + "loss": 0.6655785786478143, + "n_samples": 2907 }, { - "loss": 0.5798238261811381, - "n_samples": 969 + "loss": 0.6152775046503088, + "n_samples": 2907 }, { - "loss": 0.5739961433828923, - "n_samples": 969 + "loss": 0.6202247662153858, + "n_samples": 2907 }, { - "loss": 0.5742599932540312, - "n_samples": 969 + "loss": 0.648199727435189, + "n_samples": 2907 }, { - "loss": 0.5834948123885382, - "n_samples": 969 - }, - { - "loss": 0.554078846570139, - "n_samples": 969 - }, - { - "loss": 0.5714933996322354, - "n_samples": 969 - }, - { - "loss": 0.5384107524350331, - "n_samples": 969 - }, - { - "loss": 0.570854394451568, - "n_samples": 969 - }, - { - "loss": 0.5767292551642478, - "n_samples": 969 - }, - { - "loss": 0.5660079547556808, - "n_samples": 969 - }, - { - "loss": 0.5608972411514312, - "n_samples": 969 - }, - { - "loss": 0.5620947442987263, - "n_samples": 969 - }, - { - "loss": 0.5706970894361305, - "n_samples": 969 - }, - { - "loss": 0.5702376298690974, - "n_samples": 969 - }, - { - "loss": 0.5758474825259579, - "n_samples": 969 - }, - { - "loss": 0.5673816067284844, - "n_samples": 969 - }, - { - "loss": 0.5671441179879925, - "n_samples": 969 + "loss": 0.6473217075085124, + "n_samples": 2907 } ] } \ No newline at end of file diff --git a/scripts/process_data_cv.py b/scripts/process_external_cv.py similarity index 100% rename from scripts/process_data_cv.py rename to scripts/process_external_cv.py