This commit is contained in:
RYDE-WORK 2026-01-21 19:35:55 +08:00
parent 6773929ea2
commit a2bfb26dfc
29 changed files with 1915 additions and 2053 deletions

View File

@ -73,6 +73,11 @@ data: requirements
data_pretrain: requirements
	$(PYTHON_INTERPRETER) scripts/process_external.py
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
.PHONY: data_cv
data_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
# 例如make pretrain USE_MPNN=1
MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
@ -91,6 +96,16 @@ pretrain: requirements
test_pretrain: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
## Pretrain with cross-validation (5-fold)
.PHONY: pretrain_cv
pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG)
## Evaluate CV pretrain models on test sets
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG)
## Train model (multi-task, from scratch)
.PHONY: train
train: requirements

View File

@ -0,0 +1,55 @@
smiles
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
Mix_type_Pipetting
Cargo_type_mRNA
Cargo_type_pDNA
Cargo_type_siRNA
Target_or_delivered_gene_FFL
Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
Model_type_A549
Model_type_BDMC
Model_type_BMDM
Model_type_HBEC_ALI
Model_type_HEK293T
Model_type_HeLa
Model_type_IGROV1
Model_type_Mouse
Model_type_RAW264p7
Delivery_target_body
Delivery_target_dendritic_cell
Delivery_target_generic_cell
Delivery_target_liver
Delivery_target_lung
Delivery_target_lung_epithelium
Delivery_target_macrophage
Delivery_target_muscle
Delivery_target_spleen
Route_of_administration_in_vitro
Route_of_administration_intramuscular
Route_of_administration_intratracheal
Route_of_administration_intravenous
Batch_or_individual_or_barcoded_Barcoded
Batch_or_individual_or_barcoded_Individual
Value_name_log_luminescence
Value_name_luminescence
Value_name_FFL_silencing
Value_name_Peptide_abundance
Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
quantified_delivery

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,639 @@
"""基于 Cross-Validation 的预训练脚本"""
import json
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score
import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate every MPNN ensemble ``model.pt`` checkpoint under *base_dir*.

    Expects the on-disk layout ``cv_*/fold_*/model_*/model.pt``.  Raises
    ``FileNotFoundError`` when nothing matches so callers fail fast instead of
    silently training without an ensemble.
    """
    found = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not found:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(checkpoint) for checkpoint in found]
# NOTE(review): this import sits mid-file in the original; consider moving it
# up to the top-of-file import block with the others.
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
app = typer.Typer()  # CLI app exposing the `main` (train) and `test` commands below
class EarlyStopping:
    """Early-stopping helper: stop after *patience* epochs without improvement.

    An epoch counts as an improvement only when the validation loss drops by
    more than ``min_delta`` below the best loss seen so far.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # epochs since the last improvement
        self.best_loss = float("inf")

    def __call__(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience
def warmup_cache(model: nn.Module, smiles_list: List[str], batch_size: int = 256) -> None:
    """Pre-populate the model's RDKit feature cache for *smiles_list*.

    Deduplicates the SMILES first, then feeds them to ``model.rdkit_encoder``
    in chunks of *batch_size* so later epochs hit the cache instead of
    recomputing descriptors.
    """
    unique_smiles = list(set(smiles_list))
    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
    for start in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
        model.rdkit_encoder(unique_smiles[start:start + batch_size])
    logger.success(f"Cache warmup complete. Cached {len(model.rdkit_encoder._cache)} SMILES.")
def train_epoch_delivery(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
) -> Dict[str, float]:
    """Run one training epoch on the delivery task only.

    Batches whose delivery mask is entirely False contribute nothing (no
    backward pass).  Returns the mask-weighted average MSE loss and the number
    of supervised samples seen.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    progress = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch in progress:
        smiles = batch["smiles"]
        tabular = {key: tensor.to(device) for key, tensor in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)
        optimizer.zero_grad()
        pred = model.forward_delivery(smiles, tabular).squeeze(-1)
        if not mask.any():
            # No supervised delivery labels in this batch — nothing to optimize.
            continue
        loss = nn.functional.mse_loss(pred[mask], targets[mask])
        loss.backward()
        # Clip to unit gradient norm to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        n_valid = mask.sum().item()
        loss_sum += loss.item() * n_valid
        seen += n_valid
        progress.set_postfix({"loss": loss_sum / max(seen, 1)})
    return {"loss": loss_sum / max(seen, 1), "n_samples": seen}
@torch.no_grad()
def validate_delivery(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate the delivery task on *loader*.

    Returns the mask-weighted average MSE loss plus RMSE/R^2 computed over
    all supervised samples; the extra metrics are omitted when no sample
    carried a delivery label.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    preds: List[float] = []
    labels: List[float] = []
    for batch in loader:
        smiles = batch["smiles"]
        tabular = {key: tensor.to(device) for key, tensor in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)
        pred = model.forward_delivery(smiles, tabular).squeeze(-1)
        if not mask.any():
            continue
        loss = nn.functional.mse_loss(pred[mask], targets[mask])
        n_valid = mask.sum().item()
        loss_sum += loss.item() * n_valid
        seen += n_valid
        preds.extend(pred[mask].cpu().numpy().tolist())
        labels.extend(targets[mask].cpu().numpy().tolist())
    metrics: Dict[str, float] = {"loss": loss_sum / max(seen, 1), "n_samples": seen}
    if preds:
        y_pred = np.array(preds)
        y_true = np.array(labels)
        metrics["rmse"] = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        metrics["r2"] = float(r2_score(y_true, y_pred))
    return metrics
def train_fold(
    fold_idx: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    output_dir: Path,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    config: Optional[Dict] = None,
) -> Dict:
    """Train *model* on a single CV fold and persist the best checkpoint.

    The checkpoint written to ``output_dir/fold_{fold_idx}/model.pt`` holds
    the state dict from the epoch with the lowest validation loss; the
    per-epoch history is saved alongside it as ``history.json``.

    Returns a summary dict with the best validation loss and the RMSE/R^2
    recorded at that same best epoch.
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"Training Fold {fold_idx}")
    logger.info(f"{'='*60}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the LR after 5 epochs without validation-loss improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(patience=patience)
    best_val_loss = float("inf")
    best_state = None
    best_val_rmse = 0.0
    best_val_r2 = 0.0
    history = []
    for epoch in range(epochs):
        train_metrics = train_epoch_delivery(model, train_loader, optimizer, device, epoch)
        val_metrics = validate_delivery(model, val_loader, device)
        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"Val RMSE: {val_metrics.get('rmse', 0):.4f} | "
            f"Val R²: {val_metrics.get('r2', 0):.4f} | "
            f"LR: {current_lr:.2e}"
        )
        history.append({
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "val_loss": val_metrics["loss"],
            "val_rmse": val_metrics.get("rmse", 0),
            "val_r2": val_metrics.get("r2", 0),
            "lr": current_lr,
        })
        scheduler.step(val_metrics["loss"])
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            # BUGFIX: record RMSE/R² of the *best* epoch.  The original
            # returned the last epoch's metrics (`history[-1]`) under
            # "best_*" keys, which early stopping made systematically stale
            # (the last epoch is `patience` epochs past the best one).
            best_val_rmse = val_metrics.get("rmse", 0)
            best_val_r2 = val_metrics.get("r2", 0)
            # Keep a CPU copy so checkpointing does not depend on the device.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            logger.info(f" -> New best val_loss: {best_val_loss:.4f}")
        if early_stopping(val_metrics["loss"]):
            logger.info(f"Early stopping at epoch {epoch+1}")
            break
    # Persist best checkpoint.
    fold_output_dir = output_dir / f"fold_{fold_idx}"
    fold_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = fold_output_dir / "model.pt"
    torch.save({
        "model_state_dict": best_state,
        "config": config,
        "best_val_loss": best_val_loss,
        "fold_idx": fold_idx,
    }, checkpoint_path)
    logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")
    # Persist per-epoch training history.
    history_path = fold_output_dir / "history.json"
    with open(history_path, "w") as f:
        json.dump(history, f, indent=2)
    return {
        "fold_idx": fold_idx,
        "best_val_loss": best_val_loss,
        "best_val_rmse": best_val_rmse,
        "best_val_r2": best_val_r2,
        "epochs_trained": len(history),
    }
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_ensemble_paths: Optional[List[str]] = None,
    mpnn_device: str = "cpu",
) -> nn.Module:
    """Build an LNP model, with the MPNN encoder when *use_mpnn* is set.

    The architectural hyperparameters are shared between both variants; the
    MPNN-specific arguments are only forwarded when the encoder is enabled.
    """
    shared = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )
    if not use_mpnn:
        return LNPModelWithoutMPNN(**shared)
    return LNPModel(
        mpnn_ensemble_paths=mpnn_ensemble_paths,
        mpnn_device=mpnn_device,
        **shared,
    )
@app.command()
def main(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    output_dir: Path = MODELS_DIR / "pretrain_cv",
    # Model hyperparameters
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # MPNN options
    use_mpnn: bool = False,
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,
    mpnn_device: str = "cpu",
    # Optimization
    batch_size: int = 64,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    # Device (resolved at import time; override with --device)
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Pretrain the LNP delivery model with 5-fold cross-validation.

    Each fold is trained from a fresh model initialization and its best
    checkpoint is saved to ``output_dir/fold_x/model.pt``.  Pass
    ``--use-mpnn`` to enable the MPNN encoder.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)
    # Resolve MPNN checkpoints: explicit ensemble > single checkpoint > auto-detect.
    mpnn_paths = None
    if use_mpnn:
        if mpnn_ensemble_paths:
            mpnn_paths = mpnn_ensemble_paths.split(",")
            logger.info(f"Using provided MPNN ensemble paths: {len(mpnn_paths)} models")
        elif mpnn_checkpoint:
            mpnn_paths = [mpnn_checkpoint]
            logger.info(f"Using single MPNN checkpoint: {mpnn_checkpoint}")
        else:
            logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
            mpnn_paths = find_mpnn_ensemble_paths()
            logger.info(f"Found {len(mpnn_paths)} MPNN models")
    # Discover fold directories produced by `make data_cv`.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        logger.info("Please run 'make data_cv' first to process CV data.")
        raise typer.Exit(1)
    logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Persist the full run configuration next to the checkpoints.
    config = {
        "d_model": d_model,
        "num_heads": num_heads,
        "n_attn_layers": n_attn_layers,
        "fusion_strategy": fusion_strategy,
        "head_hidden_dim": head_hidden_dim,
        "dropout": dropout,
        "use_mpnn": use_mpnn,
        "mpnn_ensemble_paths": mpnn_paths,
        "lr": lr,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
        "patience": patience,
    }
    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    logger.info(f"Saved config to {config_path}")
    fold_results = []
    # Collect every SMILES up front so a single RDKit cache warmup covers all folds.
    all_smiles = set()
    for fold_dir in fold_dirs:
        for split in ["train", "valid"]:
            df = pd.read_parquet(fold_dir / f"{split}.parquet")
            all_smiles.update(df["smiles"].tolist())
    for fold_pos, fold_dir in enumerate(fold_dirs):
        fold_idx = int(fold_dir.name.split("_")[1])
        # Load this fold's data splits.
        train_df = pd.read_parquet(fold_dir / "train.parquet")
        val_df = pd.read_parquet(fold_dir / "valid.parquet")
        logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")
        train_dataset = ExternalDeliveryDataset(train_df)
        val_dataset = ExternalDeliveryDataset(val_df)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )
        # Fresh model per fold (no weight sharing between folds).
        model = create_model(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model = model.to(device)
        # BUGFIX: warm the cache on the first *processed* fold.  The original
        # keyed this on `fold_idx == 0` (parsed from the directory name), so
        # warmup was silently skipped whenever fold numbering did not start
        # at zero.
        if fold_pos == 0:
            warmup_cache(model, list(all_smiles), batch_size=256)
        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
        result = train_fold(
            fold_idx=fold_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            device=device,
            output_dir=output_dir,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            patience=patience,
            config=config,
        )
        fold_results.append(result)
    # Aggregate and report cross-fold statistics.
    logger.info("\n" + "=" * 60)
    logger.info("CROSS-VALIDATION TRAINING COMPLETE")
    logger.info("=" * 60)
    val_losses = [r["best_val_loss"] for r in fold_results]
    val_rmses = [r["best_val_rmse"] for r in fold_results]
    val_r2s = [r["best_val_r2"] for r in fold_results]
    logger.info("\n[Per-Fold Results]")
    for r in fold_results:
        logger.info(
            f" Fold {r['fold_idx']}: "
            f"Val Loss={r['best_val_loss']:.4f}, "
            f"RMSE={r['best_val_rmse']:.4f}, "
            f"R²={r['best_val_r2']:.4f}, "
            f"Epochs={r['epochs_trained']}"
        )
    logger.info("\n[Summary Statistics]")
    logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
    logger.info(f" Val RMSE: {np.mean(val_rmses):.4f} ± {np.std(val_rmses):.4f}")
    logger.info(f" Val R²: {np.mean(val_r2s):.4f} ± {np.std(val_r2s):.4f}")
    # Persist the CV summary.
    cv_results = {
        "fold_results": fold_results,
        "summary": {
            "val_loss_mean": float(np.mean(val_losses)),
            "val_loss_std": float(np.std(val_losses)),
            "val_rmse_mean": float(np.mean(val_rmses)),
            "val_rmse_std": float(np.std(val_rmses)),
            "val_r2_mean": float(np.mean(val_r2s)),
            "val_r2_std": float(np.std(val_r2s)),
        },
        "config": config,
    }
    results_path = output_dir / "cv_results.json"
    with open(results_path, "w") as f:
        json.dump(cv_results, f, indent=2)
    logger.success(f"Saved CV results to {results_path}")
@app.command()
def test(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    model_dir: Path = MODELS_DIR / "pretrain_cv",
    output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate the CV-pretrained models on the held-out test sets.

    Each fold's model is evaluated on that fold's own test split; per-fold
    and pooled metrics are written to *output_path*.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)
    # Discover fold directories.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        raise typer.Exit(1)
    logger.info(f"Found {len(fold_dirs)} folds")
    fold_results = []
    all_preds = []
    all_targets = []
    for fold_dir in fold_dirs:
        fold_idx = int(fold_dir.name.split("_")[1])
        model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
        test_path = fold_dir / "test.parquet"
        if not model_path.exists():
            logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
            continue
        if not test_path.exists():
            logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
            continue
        # Rebuild the model from the checkpoint's stored config.
        checkpoint = torch.load(model_path, map_location=device)
        config = checkpoint["config"]
        use_mpnn = config.get("use_mpnn", False)
        mpnn_paths = config.get("mpnn_ensemble_paths")
        if use_mpnn and not mpnn_paths:
            mpnn_paths = find_mpnn_ensemble_paths()
        model = create_model(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.eval()
        # Load this fold's held-out test split.
        test_df = pd.read_parquet(test_path)
        test_dataset = ExternalDeliveryDataset(test_df)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )
        # Collect predictions over supervised samples only.
        fold_preds = []
        fold_targets = []
        with torch.no_grad():
            for batch in test_loader:
                smiles = batch["smiles"]
                tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                targets = batch["targets"]["delivery"].to(device)
                mask = batch["mask"]["delivery"].to(device)
                pred = model.forward_delivery(smiles, tabular).squeeze(-1)
                if mask.any():
                    fold_preds.extend(pred[mask].cpu().numpy().tolist())
                    fold_targets.extend(targets[mask].cpu().numpy().tolist())
        # ROBUSTNESS: a fold with no supervised test samples would make
        # r2_score / corrcoef raise on empty arrays — skip it with a warning.
        if not fold_preds:
            logger.warning(f"Fold {fold_idx}: no supervised test samples, skipping")
            continue
        # Per-fold metrics.
        fold_preds = np.array(fold_preds)
        fold_targets = np.array(fold_targets)
        mse = float(mean_squared_error(fold_targets, fold_preds))
        rmse = float(np.sqrt(mse))
        r2 = float(r2_score(fold_targets, fold_preds))
        mae = float(np.mean(np.abs(fold_targets - fold_preds)))
        corr = float(np.corrcoef(fold_targets, fold_preds)[0, 1])
        fold_results.append({
            "fold_idx": fold_idx,
            "n_samples": len(fold_preds),
            "mse": mse,
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "correlation": corr,
        })
        all_preds.extend(fold_preds.tolist())
        all_targets.extend(fold_targets.tolist())
        logger.info(
            f"Fold {fold_idx}: n={len(fold_preds)}, "
            f"RMSE={rmse:.4f}, R²={r2:.4f}, MAE={mae:.4f}, Corr={corr:.4f}"
        )
    # ROBUSTNESS: with zero evaluated folds the pooled metrics below would
    # crash on empty arrays — fail with a clear message instead.
    if not fold_results:
        logger.error("No folds were evaluated; nothing to report.")
        raise typer.Exit(1)
    # Pooled metrics across all folds' samples.
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    overall_mse = float(mean_squared_error(all_targets, all_preds))
    overall_rmse = float(np.sqrt(overall_mse))
    overall_r2 = float(r2_score(all_targets, all_preds))
    overall_mae = float(np.mean(np.abs(all_targets - all_preds)))
    overall_corr = float(np.corrcoef(all_targets, all_preds)[0, 1])
    # Cross-fold summary statistics.
    rmses = [r["rmse"] for r in fold_results]
    r2s = [r["r2"] for r in fold_results]
    logger.info("\n" + "=" * 60)
    logger.info("CV TEST EVALUATION RESULTS")
    logger.info("=" * 60)
    logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
    logger.info(f" RMSE: {np.mean(rmses):.4f} ± {np.std(rmses):.4f}")
    logger.info(f" R²: {np.mean(r2s):.4f} ± {np.std(r2s):.4f}")
    logger.info(f"\n[Overall (all {len(all_preds)} samples pooled)]")
    logger.info(f" RMSE: {overall_rmse:.4f}")
    logger.info(f" R²: {overall_r2:.4f}")
    logger.info(f" MAE: {overall_mae:.4f}")
    logger.info(f" Correlation: {overall_corr:.4f}")
    # Persist results.
    results = {
        "fold_results": fold_results,
        "summary_stats": {
            "rmse_mean": float(np.mean(rmses)),
            "rmse_std": float(np.std(rmses)),
            "r2_mean": float(np.mean(r2s)),
            "r2_std": float(np.std(r2s)),
        },
        "overall": {
            "n_samples": len(all_preds),
            "mse": overall_mse,
            "rmse": overall_rmse,
            "mae": overall_mae,
            "r2": overall_r2,
            "correlation": overall_corr,
        },
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    logger.success(f"Saved test results to {output_path}")
# Allow running this module directly (in addition to `python -m ...` / make targets).
if __name__ == "__main__":
    app()

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -0,0 +1,21 @@
{
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"mpnn_ensemble_paths": [
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt",
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt",
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt",
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt",
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt"
],
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 64,
"epochs": 50,
"patience": 10
}

View File

@ -0,0 +1,154 @@
[
{
"epoch": 1,
"train_loss": 0.7841362829904129,
"val_loss": 0.9230800325498075,
"val_rmse": 0.9607705443260125,
"val_r2": 0.20751899565844845,
"lr": 0.0001
},
{
"epoch": 2,
"train_loss": 0.6463805462366321,
"val_loss": 0.9147881584284676,
"val_rmse": 0.9564455868655657,
"val_r2": 0.2146377239324061,
"lr": 0.0001
},
{
"epoch": 3,
"train_loss": 0.5832159742587586,
"val_loss": 0.8564491166460212,
"val_rmse": 0.9254453497069719,
"val_r2": 0.2647228727252957,
"lr": 0.0001
},
{
"epoch": 4,
"train_loss": 0.5554495169353261,
"val_loss": 0.9521924139543092,
"val_rmse": 0.9758034723689635,
"val_r2": 0.18252548972094385,
"lr": 0.0001
},
{
"epoch": 5,
"train_loss": 0.5284749799126208,
"val_loss": 0.9002334602240055,
"val_rmse": 0.9488063310427656,
"val_r2": 0.22713320447964735,
"lr": 0.0001
},
{
"epoch": 6,
"train_loss": 0.5008455182478028,
"val_loss": 0.8869625280360692,
"val_rmse": 0.9417868917406855,
"val_r2": 0.23852651728338659,
"lr": 0.0001
},
{
"epoch": 7,
"train_loss": 0.4812185961924912,
"val_loss": 0.8656982819227316,
"val_rmse": 0.9304290898141719,
"val_r2": 0.25678226981933505,
"lr": 0.0001
},
{
"epoch": 8,
"train_loss": 0.46716415395603483,
"val_loss": 0.8974583646500499,
"val_rmse": 0.9473427898099364,
"val_r2": 0.22951567180333565,
"lr": 0.0001
},
{
"epoch": 9,
"train_loss": 0.44227657012969945,
"val_loss": 0.8423525478929029,
"val_rmse": 0.917797665041863,
"val_r2": 0.2768250098825846,
"lr": 0.0001
},
{
"epoch": 10,
"train_loss": 0.4192214831283541,
"val_loss": 0.8834660291599854,
"val_rmse": 0.9399287338143215,
"val_r2": 0.24152834743071827,
"lr": 0.0001
},
{
"epoch": 11,
"train_loss": 0.4261854484762579,
"val_loss": 0.8901423802745356,
"val_rmse": 0.9434735654804721,
"val_r2": 0.23579658456782115,
"lr": 0.0001
},
{
"epoch": 12,
"train_loss": 0.40934500416213515,
"val_loss": 0.8902992367456848,
"val_rmse": 0.9435566978740579,
"val_r2": 0.2356619059496775,
"lr": 0.0001
},
{
"epoch": 13,
"train_loss": 0.3993083212922134,
"val_loss": 0.8952986713233702,
"val_rmse": 0.946202241268488,
"val_r2": 0.23136979638836497,
"lr": 0.0001
},
{
"epoch": 14,
"train_loss": 0.38611710345412054,
"val_loss": 0.898042505591281,
"val_rmse": 0.9476510417548493,
"val_r2": 0.22901418082171254,
"lr": 0.0001
},
{
"epoch": 15,
"train_loss": 0.3662278820512592,
"val_loss": 0.8985377061208856,
"val_rmse": 0.9479122944969771,
"val_r2": 0.2285890244824833,
"lr": 0.0001
},
{
"epoch": 16,
"train_loss": 0.33336883667983097,
"val_loss": 0.8594161927771942,
"val_rmse": 0.9270470292726928,
"val_r2": 0.26217556409941023,
"lr": 5e-05
},
{
"epoch": 17,
"train_loss": 0.33174229304183017,
"val_loss": 0.8942263270711439,
"val_rmse": 0.945635409387207,
"val_r2": 0.23229043171319064,
"lr": 5e-05
},
{
"epoch": 18,
"train_loss": 0.3200935872264761,
"val_loss": 0.8843722421096231,
"val_rmse": 0.9404106786278892,
"val_r2": 0.24075034122444328,
"lr": 5e-05
},
{
"epoch": 19,
"train_loss": 0.3140528901624022,
"val_loss": 0.8707393039336969,
"val_rmse": 0.9331341268749274,
"val_r2": 0.2524544731269053,
"lr": 5e-05
}
]

Binary file not shown.

View File

@ -0,0 +1,130 @@
[
{
"epoch": 1,
"train_loss": 0.7586579631889938,
"val_loss": 0.7181247283430661,
"val_rmse": 0.8474223956247073,
"val_r2": 0.27072978816447013,
"lr": 0.0001
},
{
"epoch": 2,
"train_loss": 0.6169086206378913,
"val_loss": 0.7122436411609591,
"val_rmse": 0.8439452896464879,
"val_r2": 0.27670212861312427,
"lr": 0.0001
},
{
"epoch": 3,
"train_loss": 0.5716009976577638,
"val_loss": 0.724399374180903,
"val_rmse": 0.8511165408178202,
"val_r2": 0.2643577544134178,
"lr": 0.0001
},
{
"epoch": 4,
"train_loss": 0.5406081575331337,
"val_loss": 0.8086859301147815,
"val_rmse": 0.8992696715048705,
"val_r2": 0.17876302728821758,
"lr": 0.0001
},
{
"epoch": 5,
"train_loss": 0.5032305184360722,
"val_loss": 0.7551608490501026,
"val_rmse": 0.8689999128971181,
"val_r2": 0.23311884509134373,
"lr": 0.0001
},
{
"epoch": 6,
"train_loss": 0.49779197957664073,
"val_loss": 0.7009093638175043,
"val_rmse": 0.8372033005770025,
"val_r2": 0.2882123252930564,
"lr": 0.0001
},
{
"epoch": 7,
"train_loss": 0.4720957023516007,
"val_loss": 0.7939711763762837,
"val_rmse": 0.8910506016905562,
"val_r2": 0.1937061718828037,
"lr": 0.0001
},
{
"epoch": 8,
"train_loss": 0.45880254612337634,
"val_loss": 0.7781202286020521,
"val_rmse": 0.8821112329914439,
"val_r2": 0.209803130396238,
"lr": 0.0001
},
{
"epoch": 9,
"train_loss": 0.43299420961863133,
"val_loss": 0.7305147618332145,
"val_rmse": 0.8547015607274253,
"val_r2": 0.25814744997562133,
"lr": 0.0001
},
{
"epoch": 10,
"train_loss": 0.4271428178197404,
"val_loss": 0.736324011356838,
"val_rmse": 0.8580932378454483,
"val_r2": 0.25224804192225836,
"lr": 0.0001
},
{
"epoch": 11,
"train_loss": 0.4049153788182662,
"val_loss": 0.7328011674777642,
"val_rmse": 0.8560380622891226,
"val_r2": 0.2558255581383374,
"lr": 0.0001
},
{
"epoch": 12,
"train_loss": 0.4011265030265379,
"val_loss": 0.8308254960889787,
"val_rmse": 0.9114962980441127,
"val_r2": 0.15627985591441118,
"lr": 0.0001
},
{
"epoch": 13,
"train_loss": 0.36400350090994316,
"val_loss": 0.7314743397036573,
"val_rmse": 0.8552627260066816,
"val_r2": 0.25717298455584947,
"lr": 5e-05
},
{
"epoch": 14,
"train_loss": 0.35224626399225445,
"val_loss": 0.7428758944520272,
"val_rmse": 0.8619024879725893,
"val_r2": 0.24559446076977343,
"lr": 5e-05
},
{
"epoch": 15,
"train_loss": 0.33604996105846857,
"val_loss": 0.7853316709722159,
"val_rmse": 0.8861894055728078,
"val_r2": 0.20247977173777976,
"lr": 5e-05
},
{
"epoch": 16,
"train_loss": 0.3375160908226994,
"val_loss": 0.7890364486366602,
"val_rmse": 0.888277234295553,
"val_r2": 0.19871749006650596,
"lr": 5e-05
}
]

Binary file not shown.

Binary file not shown.

View File

@ -1,285 +1,309 @@
{ {
"train": [ "train": [
{ {
"loss": 0.8151603855680482, "loss": 0.8244676398801744,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.6867990841792819, "loss": 0.6991508170533461,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.645540308540888, "loss": 0.6388374940987616,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.5923176541020599, "loss": 0.6008581508669937,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.5720762926262872, "loss": 0.584832567446085,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.5477570670417328, "loss": 0.5481657371815157,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.5280393017717573, "loss": 0.5368926340308079,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.5122504676513313, "loss": 0.5210388793613561,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.49667307051028314, "loss": 0.49758357966374045,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.486139440352648, "loss": 0.49256294099457043,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4749755339466122, "loss": 0.4697267088016886,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4636757543530298, "loss": 0.45763822707571084,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4543497681877452, "loss": 0.4495221330627172,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4408158337956461, "loss": 0.446159594079631,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4419790126221837, "loss": 0.4327090857889029,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.42850686623585116, "loss": 0.4249273364101852,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.41607048387867007, "loss": 0.4216959138704459,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.427172136486513, "loss": 0.416526201182502,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.4125568530569382, "loss": 0.40368679039741573,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.39480836287767923, "loss": 0.4051084730032182,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3885056775666858, "loss": 0.38971701020385785,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3894976457588827, "loss": 0.39155546386038786,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3890058272899995, "loss": 0.37976963541784114,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3741690826284791, "loss": 0.36484339719805037,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3534914434345719, "loss": 0.36232607571196496,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3349389765134386, "loss": 0.3345973272380199,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.32965143874976194, "loss": 0.31767916518768957,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.32094062546116675, "loss": 0.32065429246052457,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.32526135008251184, "loss": 0.3171297926146043,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.31289531808423826, "loss": 0.3122120894173009,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.3088379208288558, "loss": 0.3135035038404461,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.2994744991261045, "loss": 0.2987745178222875,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.2981521815160671, "loss": 0.2914867957853393,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.29143649446979303, "loss": 0.2983839795507705,
"n_samples": 8721 "n_samples": 8721
}, },
{ {
"loss": 0.29075756723379653, "loss": 0.2826709597875678,
"n_samples": 8721
},
{
"loss": 0.2731766632569382,
"n_samples": 8721
},
{
"loss": 0.27726896305742266,
"n_samples": 8721
},
{
"loss": 0.27864557847067956,
"n_samples": 8721 "n_samples": 8721
} }
], ],
"val": [ "val": [
{ {
"loss": 0.7625711447683281, "loss": 0.7601077516012517,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.7092331695236781, "loss": 0.7119935319611901,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.7014068689723995, "loss": 0.6461842978148269,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6595172673863646, "loss": 0.7006978391063226,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6312279044905191, "loss": 0.6533874032943979,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6349272860831151, "loss": 0.6413641451743611,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6587623598744133, "loss": 0.6168395132979742,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6093261837651732, "loss": 0.6095251602162025,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6125607111474924, "loss": 0.5887809592626905,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6005943137518024, "loss": 0.5655298325376368,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6876292386783289, "loss": 0.5809201743872788,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5940848466228036, "loss": 0.5897585974912033,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5820883587079644, "loss": 0.5732012489662573,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6302792748938035, "loss": 0.5607388911786094,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5849901610914275, "loss": 0.5717580675371414,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5830434826428553, "loss": 0.5553950037657291,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5643168952858116, "loss": 0.5778171792857049,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5592790719340829, "loss": 0.5602665468127734,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.600335100686833, "loss": 0.5475307451359259,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5646457721097674, "loss": 0.551515599314827,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.6288956836004376, "loss": 0.5755438121541243,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5771863222183704, "loss": 0.5798238261811381,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5738056593250687, "loss": 0.5739961433828923,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5636531712085593, "loss": 0.5742599932540312,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5465074849879163, "loss": 0.5834948123885382,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5701294839843508, "loss": 0.554078846570139,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5570075802420438, "loss": 0.5714933996322354,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5711473401701241, "loss": 0.5384107524350331,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5576858864741552, "loss": 0.570854394451568,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5624132422716871, "loss": 0.5767292551642478,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5655298555506272, "loss": 0.5660079547556808,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5568078993151677, "loss": 0.5608972411514312,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.567752199383958, "loss": 0.5620947442987263,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5683093779442603, "loss": 0.5706970894361305,
"n_samples": 969 "n_samples": 969
}, },
{ {
"loss": 0.5741443767974497, "loss": 0.5702376298690974,
"n_samples": 969
},
{
"loss": 0.5758474825259579,
"n_samples": 969
},
{
"loss": 0.5673816067284844,
"n_samples": 969
},
{
"loss": 0.5671441179879925,
"n_samples": 969 "n_samples": 969
} }
] ]

View File

@ -0,0 +1,18 @@
{
"model_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/pretrain_delivery.pt",
"val_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/processed/val_pretrain.parquet",
"n_samples": 969,
"metrics": {
"mse": 0.5384107828140259,
"rmse": 0.7337648007461423,
"mae": 0.5158410668373108,
"r2": 0.4877620374678041,
"correlation": 0.7059544816157891
},
"statistics": {
"y_true_mean": 0.009126194752752781,
"y_true_std": 1.0252292156219482,
"y_pred_mean": 0.008337541483342648,
"y_pred_std": 0.8293642997741699
}
}

234
scripts/process_data_cv.py Normal file
View File

@ -0,0 +1,234 @@
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
LNPDatasetConfig,
COMP_COLS,
HELP_COLS,
get_phys_cols,
get_exp_cols,
EXP_ONEHOT_SPECS,
PHYS_ONEHOT_SPECS,
)
# Typer CLI application; commands below are registered via @app.command().
app = typer.Typer()
# Mapping from CV extra_x column names to model column names.
# NOTE(review): this mapping is not referenced anywhere in this module —
# presumably consumed elsewhere or dead code; confirm before removing.
CV_COL_MAPPING = {
    # Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
    "Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
    "Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
    # Helper_lipid_ID_None is not used by the model; ignore it
}
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load the data for a single CV split directory.

    For each of the three splits (train/valid/test) the directory is expected
    to contain:
      - ``<split>.csv``          main data (smiles, quantified_delivery)
      - ``<split>_extra_x.csv``  optional pre-one-hot-encoded extra features
      - ``<split>_metadata.csv`` optional metadata (Purity, Mix_type, ...)

    Args:
        cv_dir: CV split directory, e.g. ``cv_0/``.

    Returns:
        Tuple of merged ``(train_df, valid_df, test_df)`` DataFrames.

    Raises:
        FileNotFoundError: if a required main CSV is missing.
        ValueError: if an extra_x file's row count does not match its main file.
    """
    splits = {}
    for split_name in ["train", "valid", "test"]:
        # Main data: smiles, quantified_delivery
        main_path = cv_dir / f"{split_name}.csv"
        extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
        metadata_path = cv_dir / f"{split_name}_metadata.csv"
        if not main_path.exists():
            raise FileNotFoundError(f"Missing {main_path}")
        main_df = pd.read_csv(main_path)
        # Extra features (already one-hot encoded), merged column-wise by row index.
        if extra_x_path.exists():
            extra_x_df = pd.read_csv(extra_x_path)
            # Explicit raise instead of `assert` so validation survives `python -O`.
            if len(main_df) != len(extra_x_df):
                raise ValueError(f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}")
            df = pd.concat([main_df, extra_x_df], axis=1)
        else:
            df = main_df
            logger.warning(f" {split_name}_extra_x.csv not found, using main data only")
        # Optionally pull selected columns from metadata when not already present.
        if metadata_path.exists():
            metadata_df = pd.read_csv(metadata_path)
            for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]:
                if col in metadata_df.columns and col not in df.columns:
                    df[col] = metadata_df[col]
        splits[split_name] = df
    return splits["train"], splits["valid"], splits["test"]
def _apply_onehot_specs(df: pd.DataFrame, specs) -> None:
    """Ensure every one-hot column ``<col>_<value>`` from *specs* exists in *df* (in place).

    A missing one-hot column is derived from the raw categorical column when
    that column is present, otherwise filled with 0.0. Existing one-hot
    columns get NaNs replaced with 0.0 and are cast to float.
    """
    # `specs` is expected to map column name -> iterable of category values
    # (same shape as PHYS_ONEHOT_SPECS / EXP_ONEHOT_SPECS) — TODO confirm.
    for col, values in specs.items():
        for v in values:
            onehot_col = f"{col}_{v}"
            if onehot_col not in df.columns:
                # Derive from the raw categorical column when available.
                if col in df.columns:
                    df[onehot_col] = (df[col] == v).astype(float)
                else:
                    df[onehot_col] = 0.0
            else:
                df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)


def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Align a CV split DataFrame to the column layout the model expects.

    The CV extra_x data already contains most of the one-hot encodings, but we
    still need to:
      1. add any missing one-hot columns (filled with 0),
      2. derive phys-token one-hots from metadata columns
         (Purity, Mix_type, Cargo_type, Target_or_delivered_gene),
      3. derive the Value_name one-hots.

    Returns:
        A processed copy; the input DataFrame is not modified.
    """
    df = df.copy()
    # 1. Composition columns: coerce to numeric, default missing/NaN to 0.0.
    for col in COMP_COLS:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
        else:
            df[col] = 0.0
    # 2. Helper-lipid columns: ensure presence, NaN -> 0.0, float dtype.
    for col in HELP_COLS:
        if col not in df.columns:
            df[col] = 0.0
        else:
            df[col] = df[col].fillna(0.0).astype(float)
    # 3. Phys-token one-hot columns (was duplicated inline; now shared helper).
    _apply_onehot_specs(df, PHYS_ONEHOT_SPECS)
    # 4. Exp-token one-hot columns.
    _apply_onehot_specs(df, EXP_ONEHOT_SPECS)
    # 5. Target column: coerce to numeric (invalid entries become NaN).
    if "quantified_delivery" in df.columns:
        df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
    return df
def get_feature_columns() -> List[str]:
    """Return the full, ordered list of feature column names for the model.

    Order: smiles, composition, phys, helper-lipid, exp columns, then the
    quantified_delivery target — taken from the default LNPDatasetConfig.
    """
    cfg = LNPDatasetConfig()
    cols: List[str] = ["smiles"]
    for group in (cfg.comp_cols, cfg.phys_cols, cfg.help_cols, cfg.exp_cols):
        cols.extend(group)
    cols.append("quantified_delivery")
    return cols
@app.command()
def main(
    data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
    output_dir: Path = PROCESSED_DATA_DIR / "cv",
    n_folds: int = 5,
):
    """
    Process the cross-validation data into the parquet files the model needs.

    Output layout:
      - processed/cv/fold_0/train.parquet
      - processed/cv/fold_0/valid.parquet
      - processed/cv/fold_0/test.parquet
      - processed/cv/fold_1/...
      - processed/cv/feature_columns.txt
    """
    logger.info(f"Processing CV data from {data_dir}")
    # Discover every cv_* split directory, in sorted (deterministic) order.
    cv_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_"))
    if len(cv_dirs) == 0:
        logger.error(f"No cv_* directories found in {data_dir}")
        raise typer.Exit(1)
    if len(cv_dirs) != n_folds:
        logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
    logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
    feature_cols = get_feature_columns()
    output_dir.mkdir(parents=True, exist_ok=True)
    for fold_idx, cv_dir in enumerate(cv_dirs):
        # NOTE: folds are numbered by enumeration order, not by the cv_* suffix.
        fold_name = f"fold_{fold_idx}"
        fold_output_dir = output_dir / fold_name
        fold_output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Processing {cv_dir.name} -> {fold_name}")
        # Load and merge the three splits for this fold.
        train_df, valid_df, test_df = load_cv_split(cv_dir)
        logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
        # Align each split to the model's column layout.
        frames = [process_cv_dataframe(frame) for frame in (train_df, valid_df, test_df)]
        # Guarantee every expected column exists (empty string for smiles,
        # 0.0 otherwise), then restrict to exactly the feature columns.
        for pos, frame in enumerate(frames):
            for col in feature_cols:
                if col not in frame.columns:
                    frame[col] = "" if col == "smiles" else 0.0
            frames[pos] = frame[feature_cols]
        # Persist the three splits as parquet.
        for split_name, frame in zip(("train", "valid", "test"), frames):
            frame.to_parquet(fold_output_dir / f"{split_name}.parquet", index=False)
        logger.success(f" Saved to {fold_output_dir}")
    # Record the feature-column order alongside the folds.
    cols_path = output_dir / "feature_columns.txt"
    with open(cols_path, "w") as f:
        f.write("\n".join(feature_cols))
    logger.success(f"Saved feature columns to {cols_path}")
    logger.info("\n" + "=" * 60)
    logger.info("CV DATA PROCESSING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Number of folds: {len(cv_dirs)}")
# Allow running this module directly as a script; typer dispatches the command.
if __name__ == "__main__":
    app()