整合与baseline比对的相关脚本

This commit is contained in:
RYDE-WORK 2026-02-26 18:31:22 +08:00
parent c7c33e3f48
commit 00f51f37f0
4 changed files with 32 additions and 36 deletions

View File

@@ -67,7 +67,7 @@ logs/
# Models (will be mounted as volume or copied explicitly)
# Note: models/final/ is copied in Dockerfile
models/finetune_cv/
models/pretrain_cv/
models/benchmark/
models/mpnn/
models/*.pt
models/*.json

View File

@@ -76,10 +76,10 @@ data_final: requirements
data_pretrain: requirements
$(PYTHON_INTERPRETER) scripts/process_external.py
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
.PHONY: data_pretrain_cv
data_pretrain_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
## Process baseline CV data for benchmark (external/all_amine_split_for_LiON -> processed/benchmark)
.PHONY: data_benchmark
data_benchmark: requirements
$(PYTHON_INTERPRETER) scripts/process_benchmark_data.py
## Process internal data with CV splitting (interim -> processed/cv)
## Use SCAFFOLD_SPLIT=1 to enable amine-based scaffold splitting (default: random shuffle)
@@ -96,10 +96,11 @@ data_cv: requirements
pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG) $(DEVICE_FLAG)
## Pretrain with cross-validation (5-fold)
.PHONY: pretrain_cv
pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
## Benchmark on baseline CV data: 5-fold train + test (delivery only)
.PHONY: benchmark
benchmark: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.benchmark main $(MPNN_FLAG) $(DEVICE_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.benchmark test $(DEVICE_FLAG)
## Train model (multi-task, from scratch)
.PHONY: train
@@ -140,11 +141,6 @@ finetune_cv: requirements
test_pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_pretrain_cv
test_pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements

View File

@@ -1,4 +1,4 @@
"""基于 Cross-Validation 的预训练脚本"""
"""Benchmark 脚本:在 baseline 论文公开的 CV 划分上评估模型(仅 delivery 任务)"""
import json
from pathlib import Path
@@ -232,7 +232,7 @@ def train_fold(
plot_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Pretrain Fold {fold_idx} Loss Curves",
title=f"Benchmark Fold {fold_idx} Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
@@ -281,8 +281,8 @@ def create_model(
@app.command()
def main(
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
output_dir: Path = MODELS_DIR / "pretrain_cv",
data_dir: Path = PROCESSED_DATA_DIR / "benchmark",
output_dir: Path = MODELS_DIR / "benchmark",
# 模型参数
d_model: int = 256,
num_heads: int = 8,
@@ -305,7 +305,7 @@ def main(
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
基于 5-fold Cross-Validation 预训练 LNP 模型 delivery 任务
baseline 论文公开的 5-fold CV 划分上训练模型 delivery 任务
每个 fold 单独训练一个模型保存到 output_dir/fold_x/model.pt
使用 --use-mpnn 启用 MPNN encoder
@@ -332,7 +332,7 @@ def main(
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_pretrain_cv' first to process CV data.")
logger.info("Please run 'make data_benchmark' first to process benchmark CV data.")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
@@ -430,7 +430,7 @@ def main(
# 汇总结果
logger.info("\n" + "=" * 60)
logger.info("CROSS-VALIDATION TRAINING COMPLETE")
logger.info("BENCHMARK CV TRAINING COMPLETE")
logger.info("=" * 60)
val_losses = [r["best_val_loss"] for r in fold_results]
@@ -474,16 +474,16 @@ def main(
@app.command()
def test(
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
model_dir: Path = MODELS_DIR / "pretrain_cv",
output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
data_dir: Path = PROCESSED_DATA_DIR / "benchmark",
model_dir: Path = MODELS_DIR / "benchmark",
output_path: Path = MODELS_DIR / "benchmark" / "test_results.json",
batch_size: int = 64,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
测试集上评估 CV 预训练模型
baseline CV 测试集上评估 benchmark 模型
使用每个 fold 的模型在对应的测试集上评估
使用每个 fold 训练的模型在对应的测试集上评估汇总跨 fold 结果
"""
logger.info(f"Using device: {device}")
device = torch.device(device)
@@ -609,7 +609,7 @@ def test(
r2s = [r["r2"] for r in fold_results]
logger.info("\n" + "=" * 60)
logger.info("CV TEST EVALUATION RESULTS")
logger.info("BENCHMARK TEST EVALUATION RESULTS")
logger.info("=" * 60)
logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")

View File

@@ -1,4 +1,4 @@
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
"""处理 benchmark 数据脚本:将 baseline 论文公开的 CV splits 转换为模型所需的 parquet 格式"""
from pathlib import Path
from typing import Dict, List, Tuple
@@ -151,18 +151,18 @@ def get_feature_columns() -> List[str]:
@app.command()
def main(
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
output_dir: Path = PROCESSED_DATA_DIR / "benchmark",
n_folds: int = 5,
):
"""
处理 cross-validation 数据生成模型所需的 parquet 文件
处理 baseline 论文公开的 CV 划分数据生成 benchmark 所需的 parquet 文件
输出结构:
- processed/pretrain_cv/fold_0/train.parquet
- processed/pretrain_cv/fold_0/valid.parquet
- processed/pretrain_cv/fold_0/test.parquet
- processed/pretrain_cv/fold_1/...
- processed/pretrain_cv/feature_columns.txt
- processed/benchmark/fold_0/train.parquet
- processed/benchmark/fold_0/valid.parquet
- processed/benchmark/fold_0/test.parquet
- processed/benchmark/fold_1/...
- processed/benchmark/feature_columns.txt
"""
logger.info(f"Processing CV data from {data_dir}")
@@ -223,7 +223,7 @@ def main(
logger.success(f"Saved feature columns to {cols_path}")
logger.info("\n" + "=" * 60)
logger.info("CV DATA PROCESSING COMPLETE")
logger.info("BENCHMARK DATA PROCESSING COMPLETE")
logger.info("=" * 60)
logger.info(f"Output directory: {output_dir}")
logger.info(f"Number of folds: {len(cv_dirs)}")