lnp_ml/scripts/process_data_cv.py
RYDE-WORK 039be54c5a ...
2026-01-22 01:01:29 +08:00

227 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
process_dataframe,
SMILES_COL,
COMP_COLS,
HELP_COLS,
TARGET_REGRESSION,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
TARGET_BIODIST,
get_phys_cols,
get_exp_cols,
)
# Typer CLI application; commands are registered below via @app.command().
app = typer.Typer()
def amine_based_cv_split(
    df: pd.DataFrame,
    n_folds: int = 5,
    seed: int = 42,
    amine_col: str = "Amine",
) -> List[dict]:
    """Group-aware cross-validation split keyed on the amine identity.

    Steps:
        1. Collect the unique values of ``amine_col``.
        2. Shuffle them with a fixed seed.
        3. Distribute the amines round-robin into ``n_folds`` containers.
        4. For each fold ``i``:
           - validation = container[i]
           - test       = container[(i + 1) % n_folds]
           - train      = all remaining containers

    Rows sharing an amine therefore never appear in more than one of
    train/val/test within a fold (no amine leakage).

    Args:
        df: Input DataFrame; must contain ``amine_col``.
        n_folds: Number of folds. Must be >= 3, because val and test each
            consume one container and train takes the rest.
        seed: Random seed for the shuffle (reproducible splits).
        amine_col: Column name used for grouping.

    Returns:
        A list of ``n_folds`` dicts, each with keys "train", "val" and
        "test" mapping to DataFrames with a reset index.

    Raises:
        ValueError: If ``n_folds < 3`` (the train split would be empty).
    """
    # With fewer than 3 folds the train partition is empty; fail loudly
    # instead of silently producing unusable splits.
    if n_folds < 3:
        raise ValueError(f"n_folds must be >= 3, got {n_folds}")

    # Unique amines (order of first appearance), then shuffled in place.
    unique_amines = df[amine_col].unique()
    rng = np.random.RandomState(seed)
    rng.shuffle(unique_amines)
    logger.info(f"Found {len(unique_amines)} unique amines")

    # Round-robin assignment keeps container sizes balanced (+/- 1 amine).
    containers = [[] for _ in range(n_folds)]
    for i, amine in enumerate(unique_amines):
        containers[i % n_folds].append(amine)

    # Log per-container amine and sample counts for sanity checking.
    for i, container in enumerate(containers):
        container_samples = df[df[amine_col].isin(container)]
        logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples")

    # Build the train/val/test frames for each fold.
    fold_splits = []
    for i in range(n_folds):
        val_amines = set(containers[i])
        test_amines = set(containers[(i + 1) % n_folds])
        # Every container not used for val/test contributes to train.
        train_amines = set()
        for j in range(n_folds):
            if j != i and j != (i + 1) % n_folds:
                train_amines.update(containers[j])
        train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True)
        val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True)
        test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True)
        fold_splits.append({
            "train": train_df,
            "val": val_df,
            "test": test_df,
        })
        logger.info(
            f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), "
            f"val={len(val_df)} ({len(val_amines)} amines), "
            f"test={len(test_df)} ({len(test_amines)} amines)"
        )
    return fold_splits
@app.command()
def main(
    input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
    output_dir: Path = PROCESSED_DATA_DIR / "cv",
    n_folds: int = 5,
    seed: int = 42,
    amine_col: str = "Amine",
):
    """Amine-grouped cross-validation data split.

    Similar in spirit to scaffold splitting: all rows sharing the same
    amine land in one group, so there is no amine leakage between the
    train and test sets. The resulting ratio is roughly
    train:val:test ≈ 3:1:1 (with the default 5 folds).

    Output layout:
        - processed/cv/fold_0/train.parquet
        - processed/cv/fold_0/val.parquet
        - processed/cv/fold_0/test.parquet
        - processed/cv/fold_1/...
        - processed/cv/feature_columns.txt
    """
    logger.info(f"Loading data from {input_path}")
    df = pd.read_csv(input_path)
    logger.info(f"Loaded {len(df)} samples")

    # Fail fast if the grouping column is missing.
    if amine_col not in df.columns:
        logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
        raise typer.Exit(1)

    # Snapshot the amine column before processing: process_dataframe may
    # drop it. This replaces a second pd.read_csv of the same file.
    amine_values = df[amine_col].copy()

    # Column alignment, one-hot generation, etc.
    logger.info("Processing dataframe...")
    df = process_dataframe(df)
    if amine_col not in df.columns:
        # NOTE(review): assumes process_dataframe preserves row count and
        # order — the original re-read + .values assignment relied on the
        # same assumption.
        df[amine_col] = amine_values.values

    # Columns to keep in the saved folds.
    phys_cols = get_phys_cols()
    exp_cols = get_exp_cols()
    keep_cols = (
        [SMILES_COL]
        + COMP_COLS
        + phys_cols
        + HELP_COLS
        + exp_cols
        + TARGET_REGRESSION
        + TARGET_CLASSIFICATION_PDI
        + TARGET_CLASSIFICATION_EE
        + [TARGET_TOXIC]
        + TARGET_BIODIST
    )
    # Drop any column absent from the processed frame.
    keep_cols = [c for c in keep_cols if c in df.columns]

    # Perform the grouped CV split.
    logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
    fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)

    # Persist each fold as train/val/test parquet files.
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, split in enumerate(fold_splits):
        fold_dir = output_dir / f"fold_{i}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        # Restrict to the selected feature/target columns.
        train_df = split["train"][keep_cols].reset_index(drop=True)
        val_df = split["val"][keep_cols].reset_index(drop=True)
        test_df = split["test"][keep_cols].reset_index(drop=True)
        train_df.to_parquet(fold_dir / "train.parquet", index=False)
        val_df.to_parquet(fold_dir / "val.parquet", index=False)
        test_df.to_parquet(fold_dir / "test.parquet", index=False)
        logger.success(f"Saved fold {i} to {fold_dir}")

    # Write the feature-column manifest alongside the folds.
    config_path = output_dir / "feature_columns.txt"
    with open(config_path, "w") as f:
        f.write("# Feature columns configuration\n\n")
        f.write(f"# SMILES\n{SMILES_COL}\n\n")
        f.write(f"# comp token [{len(COMP_COLS)}]\n")
        f.write("\n".join(COMP_COLS) + "\n\n")
        f.write(f"# phys token [{len(phys_cols)}]\n")
        f.write("\n".join(phys_cols) + "\n\n")
        f.write(f"# help token [{len(HELP_COLS)}]\n")
        f.write("\n".join(HELP_COLS) + "\n\n")
        f.write(f"# exp token [{len(exp_cols)}]\n")
        f.write("\n".join(exp_cols) + "\n\n")
        f.write("# Targets\n")
        f.write("## Regression\n")
        f.write("\n".join(TARGET_REGRESSION) + "\n")
        f.write("## PDI classification\n")
        f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
        f.write("## EE classification\n")
        f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
        f.write("## Toxic\n")
        f.write(f"{TARGET_TOXIC}\n")
        f.write("## Biodistribution\n")
        f.write("\n".join(TARGET_BIODIST) + "\n")
    logger.success(f"Saved feature config to {config_path}")

    # Final summary.
    logger.info("\n" + "=" * 60)
    logger.info("CV DATA PROCESSING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Number of folds: {n_folds}")
    logger.info(f"Splitting method: Amine-based (column: {amine_col})")
    logger.info(f"Random seed: {seed}")
# Script entry point: hand control to the Typer CLI dispatcher.
if __name__ == "__main__":
    app()