diff --git a/Makefile b/Makefile
index d853692..160eb95 100644
--- a/Makefile
+++ b/Makefile
@@ -78,10 +78,13 @@ data_pretrain: requirements
 data_pretrain_cv: requirements
 	$(PYTHON_INTERPRETER) scripts/process_external_cv.py
 
-## Process internal data with amine-based CV splitting (interim -> processed/cv)
+## Process internal data with CV splitting (interim -> processed/cv)
+## Use SCAFFOLD_SPLIT=1 to enable amine-based scaffold splitting (default: random shuffle)
+SCAFFOLD_SPLIT_FLAG = $(if $(filter 1,$(SCAFFOLD_SPLIT)),--scaffold-split,)
+
 .PHONY: data_cv
 data_cv: requirements
-	$(PYTHON_INTERPRETER) scripts/process_data_cv.py
+	$(PYTHON_INTERPRETER) scripts/process_data_cv.py $(SCAFFOLD_SPLIT_FLAG)
 
 # MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
 # 例如:make pretrain USE_MPNN=1
diff --git a/data/processed/cv/fold_0/test.parquet b/data/processed/cv/fold_0/test.parquet
index 7799d90..a5f5215 100644
Binary files a/data/processed/cv/fold_0/test.parquet and b/data/processed/cv/fold_0/test.parquet differ
diff --git a/data/processed/cv/fold_0/train.parquet b/data/processed/cv/fold_0/train.parquet
index e753b68..e3c307e 100644
Binary files a/data/processed/cv/fold_0/train.parquet and b/data/processed/cv/fold_0/train.parquet differ
diff --git a/data/processed/cv/fold_0/val.parquet b/data/processed/cv/fold_0/val.parquet
index 1f14c35..7905976 100644
Binary files a/data/processed/cv/fold_0/val.parquet and b/data/processed/cv/fold_0/val.parquet differ
diff --git a/data/processed/cv/fold_1/test.parquet b/data/processed/cv/fold_1/test.parquet
index 30b25a8..14f37b7 100644
Binary files a/data/processed/cv/fold_1/test.parquet and b/data/processed/cv/fold_1/test.parquet differ
diff --git a/data/processed/cv/fold_1/train.parquet b/data/processed/cv/fold_1/train.parquet
index c096898..1cfa625 100644
Binary files a/data/processed/cv/fold_1/train.parquet and b/data/processed/cv/fold_1/train.parquet differ
diff --git a/data/processed/cv/fold_1/val.parquet b/data/processed/cv/fold_1/val.parquet
index 7799d90..a5f5215 100644
Binary files a/data/processed/cv/fold_1/val.parquet and b/data/processed/cv/fold_1/val.parquet differ
diff --git a/data/processed/cv/fold_2/test.parquet b/data/processed/cv/fold_2/test.parquet
index b3bb10b..5605067 100644
Binary files a/data/processed/cv/fold_2/test.parquet and b/data/processed/cv/fold_2/test.parquet differ
diff --git a/data/processed/cv/fold_2/train.parquet b/data/processed/cv/fold_2/train.parquet
index bef51fc..31e66b6 100644
Binary files a/data/processed/cv/fold_2/train.parquet and b/data/processed/cv/fold_2/train.parquet differ
diff --git a/data/processed/cv/fold_2/val.parquet b/data/processed/cv/fold_2/val.parquet
index 30b25a8..14f37b7 100644
Binary files a/data/processed/cv/fold_2/val.parquet and b/data/processed/cv/fold_2/val.parquet differ
diff --git a/data/processed/cv/fold_3/test.parquet b/data/processed/cv/fold_3/test.parquet
index 93357c0..ed3b26d 100644
Binary files a/data/processed/cv/fold_3/test.parquet and b/data/processed/cv/fold_3/test.parquet differ
diff --git a/data/processed/cv/fold_3/train.parquet b/data/processed/cv/fold_3/train.parquet
index b215888..ec51f26 100644
Binary files a/data/processed/cv/fold_3/train.parquet and b/data/processed/cv/fold_3/train.parquet differ
diff --git a/data/processed/cv/fold_3/val.parquet b/data/processed/cv/fold_3/val.parquet
index b3bb10b..5605067 100644
Binary files a/data/processed/cv/fold_3/val.parquet and b/data/processed/cv/fold_3/val.parquet differ
diff --git a/data/processed/cv/fold_4/test.parquet b/data/processed/cv/fold_4/test.parquet
index 1f14c35..7905976 100644
Binary files a/data/processed/cv/fold_4/test.parquet and b/data/processed/cv/fold_4/test.parquet differ
diff --git a/data/processed/cv/fold_4/train.parquet b/data/processed/cv/fold_4/train.parquet
index 179480e..c802f21 100644
Binary files a/data/processed/cv/fold_4/train.parquet and b/data/processed/cv/fold_4/train.parquet differ
diff --git a/data/processed/cv/fold_4/val.parquet b/data/processed/cv/fold_4/val.parquet
index 93357c0..ed3b26d 100644
Binary files a/data/processed/cv/fold_4/val.parquet and b/data/processed/cv/fold_4/val.parquet differ
diff --git a/scripts/process_data_cv.py b/scripts/process_data_cv.py
index 7be689e..1b02404 100644
--- a/scripts/process_data_cv.py
+++ b/scripts/process_data_cv.py
@@ -1,4 +1,4 @@
-"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
+"""Cross-validation split script for internal data: random split or Amine-based grouped split."""
 
 from pathlib import Path
 from typing import List
@@ -26,6 +26,86 @@ from lnp_ml.dataset import (
 app = typer.Typer()
 
 
+def random_cv_split(
+    df: pd.DataFrame,
+    n_folds: int = 5,
+    seed: int = 42,
+) -> List[dict]:
+    """
+    Randomly shuffle the samples and build cross-validation folds.
+
+    Steps:
+    1. Shuffle all samples (reproducibly, using ``seed``).
+    2. Partition the shuffled samples into ``n_folds`` containers.
+    3. For each fold i:
+       - validation = container[i]
+       - test = container[(i + 1) % n_folds]
+       - train = all remaining containers
+
+    Args:
+        df: Input DataFrame.
+        n_folds: Number of folds; must be at least 3 so that train, val and
+            test are drawn from different containers.
+        seed: Random seed used for the shuffle.
+
+    Returns:
+        A list with one dict per fold; each dict maps the keys "train",
+        "val" and "test" to a DataFrame with a fresh 0-based index.
+
+    Raises:
+        ValueError: If ``n_folds`` is smaller than 3.
+    """
+    # With fewer than 3 containers no container is left for training, and
+    # np.concatenate would fail on an empty list with an unhelpful message.
+    if n_folds < 3:
+        raise ValueError(
+            f"n_folds must be at least 3 for a train/val/test split, got {n_folds}"
+        )
+
+    # Shuffle all samples
+    df_shuffled = df.sample(frac=1, random_state=seed).reset_index(drop=True)
+    n_samples = len(df_shuffled)
+
+    logger.info(f"Total {n_samples} samples for random CV split")
+    if n_samples < n_folds:
+        logger.warning(
+            f"Only {n_samples} samples for {n_folds} folds; some splits will be empty"
+        )
+
+    # Partition the row positions into n_folds near-equal containers
+    containers = np.array_split(np.arange(n_samples), n_folds)
+
+    # Log the size of every container
+    for i, container in enumerate(containers):
+        logger.info(f"  Container {i}: {len(container)} samples")
+
+    # Build the train/val/test frames of every fold
+    fold_splits = []
+    for i in range(n_folds):
+        val_pos = i
+        test_pos = (i + 1) % n_folds
+        train_indices = np.concatenate([
+            containers[j] for j in range(n_folds)
+            if j not in (val_pos, test_pos)
+        ])
+
+        train_df = df_shuffled.iloc[train_indices].reset_index(drop=True)
+        val_df = df_shuffled.iloc[containers[val_pos]].reset_index(drop=True)
+        test_df = df_shuffled.iloc[containers[test_pos]].reset_index(drop=True)
+
+        fold_splits.append({
+            "train": train_df,
+            "val": val_df,
+            "test": test_df,
+        })
+
+        logger.info(
+            f"Fold {i}: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}"
+        )
+
+    return fold_splits
+
+
 def amine_based_cv_split(
     df: pd.DataFrame,
     n_folds: int = 5,
@@ -106,12 +186,18 @@ def main(
     n_folds: int = 5,
     seed: int = 42,
     amine_col: str = "Amine",
+    scaffold_split: bool = typer.Option(
+        False,
+        "--scaffold-split",
+        help="使用基于 Amine 的 scaffold splitting(默认:随机 shuffle)",
+    ),
 ):
     """
-    基于 Amine 分组进行 Cross-Validation 数据划分。
+    Cross-Validation 数据划分。
 
-    采用类似 scaffold splitting 的思路,将相同 Amine 的数据放在同一组,
-    确保训练集和测试集之间没有 Amine 泄露。
+    支持两种划分方式:
+    - 随机划分(默认):直接 shuffle 所有样本
+    - Scaffold splitting(--scaffold-split):基于 Amine 分组,确保同一 Amine 的数据在同一组
 
     划分比例约为 train:val:test ≈ 3:1:1
 
@@ -126,20 +212,27 @@ def main(
     df = pd.read_csv(input_path)
     logger.info(f"Loaded {len(df)} samples")
 
-    # 检查 amine 列是否存在
-    if amine_col not in df.columns:
-        logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
-        raise typer.Exit(1)
-
     # 处理数据(列对齐、one-hot 生成等)
     logger.info("Processing dataframe...")
     df = process_dataframe(df)
 
-    # 确保 Amine 列仍然存在(process_dataframe 可能不会保留它)
-    # 重新加载原始数据获取 Amine 列
-    original_df = pd.read_csv(input_path)
-    if amine_col in original_df.columns and amine_col not in df.columns:
-        df[amine_col] = original_df[amine_col].values
+    # Scaffold splitting groups by the Amine column, so it must be available
+    if scaffold_split:
+        # Re-read the raw CSV to recover the Amine column (process_dataframe may drop it)
+        original_df = pd.read_csv(input_path)
+        if amine_col not in original_df.columns:
+            logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(original_df.columns)}")
+            raise typer.Exit(1)
+        if amine_col not in df.columns:
+            # Positional copy: only valid if process_dataframe kept every row.
+            # NOTE(review): row order is assumed preserved -- confirm against process_dataframe.
+            if len(original_df) != len(df):
+                logger.error(
+                    f"Cannot restore '{amine_col}': raw data has {len(original_df)} rows, "
+                    f"processed data has {len(df)}"
+                )
+                raise typer.Exit(1)
+            df[amine_col] = original_df[amine_col].values
 
     # 定义要保留的列
     phys_cols = get_phys_cols()
@@ -162,8 +255,14 @@ def main(
     keep_cols = [c for c in keep_cols if c in df.columns]
 
     # 进行 CV 划分
-    logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
-    fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
+    if scaffold_split:
+        logger.info(f"\nPerforming {n_folds}-fold amine-based scaffold CV split (seed={seed})...")
+        fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
+        split_method = f"Amine-based scaffold (column: {amine_col})"
+    else:
+        logger.info(f"\nPerforming {n_folds}-fold random CV split (seed={seed})...")
+        fold_splits = random_cv_split(df, n_folds=n_folds, seed=seed)
+        split_method = "Random shuffle"
 
     # 保存每个 fold
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -217,5 +316,5 @@ def main(
     logger.info("=" * 60)
     logger.info(f"Output directory: {output_dir}")
     logger.info(f"Number of folds: {n_folds}")
-    logger.info(f"Splitting method: Amine-based (column: {amine_col})")
+    logger.info(f"Splitting method: {split_method}")
     logger.info(f"Random seed: {seed}")