mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
Add random CV split
This commit is contained in:
parent
ac4246c2b7
commit
871afc5988
7
Makefile
7
Makefile
@ -78,10 +78,13 @@ data_pretrain: requirements
|
||||
data_pretrain_cv: requirements
|
||||
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
|
||||
|
||||
## Process internal data with amine-based CV splitting (interim -> processed/cv)
|
||||
## Process internal data with CV splitting (interim -> processed/cv)
|
||||
## Use SCAFFOLD_SPLIT=1 to enable amine-based scaffold splitting (default: random shuffle)
|
||||
SCAFFOLD_SPLIT_FLAG = $(if $(filter 1,$(SCAFFOLD_SPLIT)),--scaffold-split,)
|
||||
|
||||
.PHONY: data_cv
|
||||
data_cv: requirements
|
||||
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
|
||||
$(PYTHON_INTERPRETER) scripts/process_data_cv.py $(SCAFFOLD_SPLIT_FLAG)
|
||||
|
||||
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
|
||||
# 例如:make pretrain USE_MPNN=1
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,4 +1,4 @@
|
||||
"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
|
||||
"""内部数据 Cross-Validation 划分脚本:支持随机划分或基于 Amine 的分组划分"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
@ -26,6 +26,71 @@ from lnp_ml.dataset import (
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def random_cv_split(
    df: pd.DataFrame,
    n_folds: int = 5,
    seed: int = 42,
) -> List[dict]:
    """
    Build cross-validation folds by randomly shuffling the samples.

    Steps:
        1. Shuffle all samples.
        2. Partition the shuffled samples into ``n_folds`` containers.
        3. For each fold ``i``:
            - validation = container[i]
            - test = container[(i + 1) % n_folds]
            - train = all remaining containers

    Args:
        df: Input DataFrame.
        n_folds: Number of folds. Must be at least 3, because each fold
            reserves one container for validation and one for test and
            needs at least one more for training.
        seed: Random seed used for the shuffle.

    Returns:
        A list with one dict per fold. Each dict has the keys ``"train"``,
        ``"val"`` and ``"test"``, mapping to DataFrames whose index has
        been reset.

    Raises:
        ValueError: If ``n_folds`` is smaller than 3.
    """
    # Fail early with a clear message: with fewer than 3 folds there is no
    # container left for training and np.concatenate would raise an
    # unrelated-looking "need at least one array to concatenate" error.
    if n_folds < 3:
        raise ValueError(
            f"n_folds must be at least 3 for a train/val/test split, got {n_folds}"
        )

    # Shuffle all samples.
    df_shuffled = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    n_samples = len(df_shuffled)

    logger.info(f"Total {n_samples} samples for random CV split")

    # Partition the positional indices into n_folds containers.
    containers = np.array_split(np.arange(n_samples), n_folds)

    # Report the size of each container.
    for i, container in enumerate(containers):
        logger.info(f"  Container {i}: {len(container)} samples")

    # Assemble the train/val/test frames for every fold.
    fold_splits = []
    for i, val_indices in enumerate(containers):
        test_slot = (i + 1) % n_folds
        test_indices = containers[test_slot]
        train_indices = np.concatenate([
            container
            for j, container in enumerate(containers)
            if j != i and j != test_slot
        ])

        train_df = df_shuffled.iloc[train_indices].reset_index(drop=True)
        val_df = df_shuffled.iloc[val_indices].reset_index(drop=True)
        test_df = df_shuffled.iloc[test_indices].reset_index(drop=True)

        fold_splits.append({
            "train": train_df,
            "val": val_df,
            "test": test_df,
        })

        logger.info(
            f"Fold {i}: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}"
        )

    return fold_splits
|
||||
|
||||
|
||||
def amine_based_cv_split(
|
||||
df: pd.DataFrame,
|
||||
n_folds: int = 5,
|
||||
@ -106,12 +171,18 @@ def main(
|
||||
n_folds: int = 5,
|
||||
seed: int = 42,
|
||||
amine_col: str = "Amine",
|
||||
scaffold_split: bool = typer.Option(
|
||||
False,
|
||||
"--scaffold-split",
|
||||
help="使用基于 Amine 的 scaffold splitting(默认:随机 shuffle)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
基于 Amine 分组进行 Cross-Validation 数据划分。
|
||||
Cross-Validation 数据划分。
|
||||
|
||||
采用类似 scaffold splitting 的思路,将相同 Amine 的数据放在同一组,
|
||||
确保训练集和测试集之间没有 Amine 泄露。
|
||||
支持两种划分方式:
|
||||
- 随机划分(默认):直接 shuffle 所有样本
|
||||
- Scaffold splitting(--scaffold-split):基于 Amine 分组,确保同一 Amine 的数据在同一组
|
||||
|
||||
划分比例约为 train:val:test ≈ 3:1:1
|
||||
|
||||
@ -126,20 +197,19 @@ def main(
|
||||
df = pd.read_csv(input_path)
|
||||
logger.info(f"Loaded {len(df)} samples")
|
||||
|
||||
# 检查 amine 列是否存在
|
||||
if amine_col not in df.columns:
|
||||
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# 处理数据(列对齐、one-hot 生成等)
|
||||
logger.info("Processing dataframe...")
|
||||
df = process_dataframe(df)
|
||||
|
||||
# 确保 Amine 列仍然存在(process_dataframe 可能不会保留它)
|
||||
# 重新加载原始数据获取 Amine 列
|
||||
original_df = pd.read_csv(input_path)
|
||||
if amine_col in original_df.columns and amine_col not in df.columns:
|
||||
df[amine_col] = original_df[amine_col].values
|
||||
# 如果使用 scaffold split,检查 amine 列是否存在
|
||||
if scaffold_split:
|
||||
# 重新加载原始数据获取 Amine 列(process_dataframe 可能不会保留它)
|
||||
original_df = pd.read_csv(input_path)
|
||||
if amine_col not in original_df.columns:
|
||||
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(original_df.columns)}")
|
||||
raise typer.Exit(1)
|
||||
if amine_col not in df.columns:
|
||||
df[amine_col] = original_df[amine_col].values
|
||||
|
||||
# 定义要保留的列
|
||||
phys_cols = get_phys_cols()
|
||||
@ -162,8 +232,14 @@ def main(
|
||||
keep_cols = [c for c in keep_cols if c in df.columns]
|
||||
|
||||
# 进行 CV 划分
|
||||
logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
|
||||
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
|
||||
if scaffold_split:
|
||||
logger.info(f"\nPerforming {n_folds}-fold amine-based scaffold CV split (seed={seed})...")
|
||||
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
|
||||
split_method = f"Amine-based scaffold (column: {amine_col})"
|
||||
else:
|
||||
logger.info(f"\nPerforming {n_folds}-fold random CV split (seed={seed})...")
|
||||
fold_splits = random_cv_split(df, n_folds=n_folds, seed=seed)
|
||||
split_method = "Random shuffle"
|
||||
|
||||
# 保存每个 fold
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -217,7 +293,7 @@ def main(
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Output directory: {output_dir}")
|
||||
logger.info(f"Number of folds: {n_folds}")
|
||||
logger.info(f"Splitting method: Amine-based (column: {amine_col})")
|
||||
logger.info(f"Splitting method: {split_method}")
|
||||
logger.info(f"Random seed: {seed}")
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user