mirror of https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 11:53:13 +08:00
...
This commit is contained in:
parent 96a27caab2
commit 0290649df1

Makefile (29 lines changed)
@@ -68,15 +68,38 @@ clean_data: requirements
data: requirements
	$(PYTHON_INTERPRETER) scripts/process_data.py

-## Train model
## Process external data for pretrain (external -> processed)
.PHONY: data_pretrain
data_pretrain: requirements
	$(PYTHON_INTERPRETER) scripts/process_external.py

# MPNN support: set USE_MPNN=1 to enable the MPNN encoder
# e.g. make pretrain USE_MPNN=1
MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)

# Backbone freezing: set FREEZE_BACKBONE=1 to freeze the backbone and train only the heads
# e.g. make finetune FREEZE_BACKBONE=1
FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)

## Pretrain on external data (delivery only)
.PHONY: pretrain
pretrain: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain $(MPNN_FLAG)

## Train model (multi-task, from scratch)
.PHONY: train
train: requirements
-	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG)

## Finetune from pretrained checkpoint (use FREEZE_BACKBONE=1 to freeze backbone)
.PHONY: finetune
finetune: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG)

## Train with hyperparameter tuning
.PHONY: tune
tune: requirements
-	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG)

## Run predictions
.PHONY: predict
README.md (101 lines changed)
@@ -4,7 +4,106 @@
  <img src="https://img.shields.io/badge/CCDS-Project%20template-328F97?logo=cookiecutter" />
</a>

-A short description of the project.
A model for predicting the drug-delivery performance of LNPs (lipid nanoparticles).

## Quick Start

### 1. Set up the environment

```bash
pixi install
pixi shell
```

### 2. Process the data

```bash
# Clean the raw data (raw -> interim)
make clean_data

# Process the internal dataset (interim -> processed)
make data

# Process the external pretraining data (external -> processed)
make data_pretrain
```

### 3. Train the model

**Option A: train directly (from scratch)**

```bash
make train
```

**Option B: pretrain + finetune (recommended)**

Pretrain on the external LiON dataset (~9,000 rows), then finetune on the internal data:

```bash
# Step 1: process the external data
make data_pretrain

# Step 2: pretrain the backbone + delivery head on the external data
make pretrain

# Step 3: load the pretrained weights and finetune multi-task on the internal data
make finetune
```

**Option C: hyperparameter tuning**

```bash
make tune
```

### 4. Evaluate and predict

```bash
# Evaluate on the test set
make test

# Generate predictions
make predict
```

## Training Pipeline in Detail

### Pretraining

On the external LiON data, train only the `quantified_delivery` task:

```bash
# 1. Process the external data first
python scripts/process_external.py

# 2. Pretrain
python -m lnp_ml.modeling.pretrain \
    --train-path data/processed/train_pretrain.parquet \
    --val-path data/processed/val_pretrain.parquet \
    --epochs 50 \
    --lr 1e-4
```

Outputs:
- `data/processed/train_pretrain.parquet`: processed training data
- `data/processed/val_pretrain.parquet`: processed validation data
- `models/pretrain_delivery.pt`: backbone + delivery head weights
- `models/pretrain_history.json`: training history
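As a quick sanity check, the saved checkpoint can be inspected from Python (a minimal sketch; the key names follow the `torch.save` call in `lnp_ml/modeling/pretrain.py`, shown further down this diff):

```python
import torch

# Load the pretrain checkpoint on CPU and inspect what was saved
ckpt = torch.load("models/pretrain_delivery.pt", map_location="cpu")
print(ckpt["config"])                    # model hyperparameters, incl. use_mpnn
print(ckpt["best_val_loss"])             # best validation MSE on delivery
print(len(ckpt["backbone_state_dict"]))  # number of backbone parameter tensors
```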

### Finetuning

Load the pretrained weights and train on the internal multi-task data:

```bash
python -m lnp_ml.modeling.train \
    --init-from-pretrain models/pretrain_delivery.pt \
    --load-delivery-head  # optional: whether to load the delivery head weights
```

Outputs:
- `models/model.pt`: full model weights
- `models/history.json`: training history
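For programmatic inference, the finetuned checkpoint can be loaded the same way `lnp_ml/modeling/predict.py` does below (a minimal sketch, assuming a model trained without MPNN):

```python
import torch

from lnp_ml.modeling.models import LNPModelWithoutMPNN

ckpt = torch.load("models/model.pt", map_location="cpu")
cfg = ckpt["config"]
model = LNPModelWithoutMPNN(
    d_model=cfg["d_model"],
    num_heads=cfg["num_heads"],
    n_attn_layers=cfg["n_attn_layers"],
    fusion_strategy=cfg["fusion_strategy"],
    head_hidden_dim=cfg["head_hidden_dim"],
    dropout=cfg["dropout"],
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
```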

## Project Organization

lnp_ml/dataset.py

@@ -346,3 +346,185 @@ def load_dataset(
        LNPDataset(val_df, config),
        LNPDataset(test_df, config),
    )


# ============ External data (delivery only) processing ============

# Value map for Value_name in the external data (spaces -> underscores)
EXTERNAL_VALUE_NAME_MAP = {
    "log luminescence": "log_luminescence",
    "Luminescence": "luminescence",
    "FFL silencing": "FFL_silencing",
    "Peptide abundance": "Peptide_abundance",
    "hEPO": "hEPO",
    "FVII silencing": "FVII_silencing",
    "GFP delivery": "GFP_delivery",
    "Discretized luminescence": "Discretized_luminescence",
}

# Value map for Mix_type in the external data
EXTERNAL_MIX_TYPE_MAP = {
    "Hand": "Pipetting",  # the external data uses "Hand" to mean "Pipetting"
    "Microfluidic": "Microfluidic",
    "Pipetting": "Pipetting",
}
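These maps are applied row-wise just below; a quick, self-contained illustration of the normalization they perform (map values copied from above):

```python
import pandas as pd

EXTERNAL_MIX_TYPE_MAP = {"Hand": "Pipetting", "Microfluidic": "Microfluidic", "Pipetting": "Pipetting"}

s = pd.Series(["Hand", "Microfluidic", None])
normalized = s.map(lambda x: EXTERNAL_MIX_TYPE_MAP.get(x, x) if pd.notna(x) else x)
print(normalized.tolist())  # ['Pipetting', 'Microfluidic', None]
```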


def process_external_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process a DataFrame of external LiON data, aligning it to the column format the model expects.

    Similar to process_dataframe, but adapted to the column-value differences in the external data:
    - spaces in Value_name values must be converted to underscores
    - "Hand" in Mix_type must be mapped to "Pipetting"
    """
    df = df.copy()

    # 1. Preprocess: map the Value_name and Mix_type values
    if "Value_name" in df.columns:
        df["Value_name"] = df["Value_name"].map(
            lambda x: EXTERNAL_VALUE_NAME_MAP.get(x, x) if pd.notna(x) else x
        )

    if "Mix_type" in df.columns:
        df["Mix_type"] = df["Mix_type"].map(
            lambda x: EXTERNAL_MIX_TYPE_MAP.get(x, x) if pd.notna(x) else x
        )

    # 2. Build the phys-token one-hot columns (generated from the raw column if absent)
    for col, values in PHYS_ONEHOT_SPECS.items():
        for v in values:
            onehot_col = f"{col}_{v}"
            if onehot_col not in df.columns:
                if col in df.columns:
                    df[onehot_col] = (df[col] == v).astype(float)
                else:
                    df[onehot_col] = 0.0

    # 3. Build the exp-token one-hot columns.
    # Some of these already exist in the external data (e.g. Model_type_*, Delivery_target_*),
    # but individual categories may be missing.
    for col, values in EXP_ONEHOT_SPECS.items():
        for v in values:
            onehot_col = f"{col}_{v}"
            if onehot_col not in df.columns:
                if col in df.columns:
                    df[onehot_col] = (df[col] == v).astype(float)
                else:
                    df[onehot_col] = 0.0
            else:
                # Make sure the dtype is float
                df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)

    # 4. Make sure the comp columns exist and are float
    for col in COMP_COLS:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
        else:
            df[col] = 0.0

    # 5. Make sure the help columns exist
    for col in HELP_COLS:
        if col not in df.columns:
            df[col] = 0.0
        else:
            df[col] = df[col].fillna(0.0).astype(float)

    # 6. Coerce quantified_delivery to numeric
    if "quantified_delivery" in df.columns:
        df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")

    return df


class ExternalDeliveryDataset(Dataset):
    """
    External LiON dataset, used only for delivery pretraining.

    Returns:
        - smiles: str
        - tabular: Dict[str, Tensor] with keys "comp", "phys", "help", "exp"
        - targets: Dict[str, Tensor] with key "delivery"
        - mask: Dict[str, Tensor] with key "delivery"
    """

    def __init__(
        self,
        df: pd.DataFrame,
        config: Optional[LNPDatasetConfig] = None,
    ):
        self.config = config or LNPDatasetConfig()
        self.df = process_external_dataframe(df)

        # Extract the data
        self.smiles = self.df[SMILES_COL].tolist()

        # Tabular features
        self.comp = self.df[self.config.comp_cols].values.astype(np.float32)
        self.phys = self.df[self.config.phys_cols].values.astype(np.float32)
        self.help = self.df[self.config.help_cols].values.astype(np.float32)
        self.exp = self.df[self.config.exp_cols].values.astype(np.float32)

        # delivery is the only target
        self.delivery = (
            self.df["quantified_delivery"].values.astype(np.float32)
            if "quantified_delivery" in self.df.columns
            else None
        )

    def __len__(self) -> int:
        return len(self.smiles)

    def __getitem__(self, idx: int) -> Dict:
        item = {
            "smiles": self.smiles[idx],
            "tabular": {
                "comp": torch.from_numpy(self.comp[idx]),
                "phys": torch.from_numpy(self.phys[idx]),
                "help": torch.from_numpy(self.help[idx]),
                "exp": torch.from_numpy(self.exp[idx]),
            },
            "targets": {},
            "mask": {},
        }

        # delivery target and mask
        if self.delivery is not None:
            item["targets"]["delivery"] = torch.tensor(self.delivery[idx], dtype=torch.float32)
            item["mask"]["delivery"] = torch.tensor(not np.isnan(self.delivery[idx]), dtype=torch.bool)

        return item
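The training scripts batch these dict-shaped items with `collate_fn` (imported from `lnp_ml.dataset` but not shown in this diff). A minimal sketch of a compatible collate, assuming the item structure returned by `__getitem__` above; this is not the repository's actual implementation:

```python
from typing import Dict, List

import torch


def collate_fn(items: List[Dict]) -> Dict:
    """Batch dataset items: SMILES stay a list of str, tensors get stacked."""
    return {
        "smiles": [it["smiles"] for it in items],
        "tabular": {
            k: torch.stack([it["tabular"][k] for it in items])
            for k in items[0]["tabular"]
        },
        "targets": {
            k: torch.stack([it["targets"][k] for it in items])
            for k in items[0]["targets"]
        },
        "mask": {
            k: torch.stack([it["mask"][k] for it in items])
            for k in items[0]["mask"]
        },
    }
```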


def load_external_dataset(
    path: Path,
    train_ratio: float = 0.9,
    seed: int = 42,
) -> Tuple[ExternalDeliveryDataset, ExternalDeliveryDataset]:
    """
    Load the external LiON dataset and split it into train/val.

    Args:
        path: path to the CSV file
        train_ratio: fraction used for training (the remainder is validation)
        seed: random seed

    Returns:
        (train_dataset, val_dataset)
    """
    df = pd.read_csv(path)

    # Drop rows where quantified_delivery is missing
    if "quantified_delivery" in df.columns:
        df = df[df["quantified_delivery"].notna()].reset_index(drop=True)

    # Shuffle
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

    n = len(df)
    n_train = int(n * train_ratio)

    train_df = df.iloc[:n_train]
    val_df = df.iloc[n_train:]

    config = LNPDatasetConfig()

    return (
        ExternalDeliveryDataset(train_df, config),
        ExternalDeliveryDataset(val_df, config),
    )
@@ -43,7 +43,13 @@ class RDKitFeaturizer:
        return np.array(MACCSkeys.GenMACCSKeys(mol).ToList(), dtype=np.float32)

    def _encode_desc(self, mol: Chem.Mol) -> np.ndarray:
-        return np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float32)
        # Compute in float64, then clip to a float32-safe range to avoid overflow
        desc_values = list(Descriptors.CalcMolDescriptors(mol).values())
        arr = np.array(desc_values, dtype=np.float64)
        # Replace inf/nan and clip into the float32 range
        arr = np.nan_to_num(arr, nan=0.0, posinf=1e10, neginf=-1e10)
        arr = np.clip(arr, -1e10, 1e10)
        return arr.astype(np.float32)

    def _encode_one(self, smiles: str) -> Dict[str, np.ndarray]:
        mol = Chem.MolFromSmiles(smiles)
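The guard matters because a few RDKit descriptors (e.g. Ipc on large molecules) can exceed the float32 range, and a direct float32 cast silently turns them into inf, which then poisons any downstream normalization. A small illustration of the failure mode this change avoids:

```python
import numpy as np

big = np.array([1e39])                   # beyond float32 max (~3.4e38)
print(big.astype(np.float32))            # [inf] -- silent overflow
print(np.clip(big, -1e10, 1e10).astype(np.float32))  # [1.e+10] -- bounded
```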
lnp_ml/modeling/models.py

@@ -121,28 +121,16 @@ class LNPModel(nn.Module):
            dropout=dropout,
        )

-    def forward(
    def _encode_and_project(
        self,
        smiles: List[str],
        tabular: Dict[str, torch.Tensor],
-    ) -> Dict[str, torch.Tensor]:
    ) -> torch.Tensor:
        """
-        Args:
-            smiles: list of SMILES strings, of length B
-            tabular: Dict[str, Tensor] containing:
-                - "comp": [B, 5] formulation ratios
-                - "phys": [B, 12] physical parameters
-                - "help": [B, 4] helper lipid
-                - "exp": [B, 32] experimental conditions
        Internal method: encode the SMILES and tabular features and return the stacked tokens.

        Returns:
-            Dict[str, Tensor]:
-                - "size": [B, 1]
-                - "pdi": [B, 4]
-                - "ee": [B, 3]
-                - "delivery": [B, 1]
-                - "biodist": [B, 7]
-                - "toxic": [B, 2]
            stacked: [B, n_tokens, d_model]
        """
        # 1. Encode SMILES
        rdkit_features = self.rdkit_encoder(smiles)  # {"morgan", "maccs", "desc"}
@@ -170,23 +158,87 @@ class LNPModel(nn.Module):
        projected = self.token_projector(all_features)  # Dict[str, [B, d_model]]

        # 4. Stack tokens: [B, n_tokens, d_model]
        # Stack in order: Channel A (chemistry) + Channel B (formulation/experiment)
        if self.use_mpnn:
            token_order = ["mpnn", "morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
        else:
            token_order = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]

        stacked = torch.stack([projected[k] for k in token_order], dim=1)
        return stacked

-        # 5. Cross Modal Attention
    def forward_backbone(
        self,
        smiles: List[str],
        tabular: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Backbone forward pass: encode -> project -> attention -> fusion, without the task heads.

        Used for the pretraining stage, or wherever features need to be extracted.

        Args:
            smiles: list of SMILES strings, of length B
            tabular: Dict[str, Tensor]

        Returns:
            fused: [B, fusion_dim] fused feature vector
        """
        # Encode + project + stack
        stacked = self._encode_and_project(smiles, tabular)

        # Cross Modal Attention
        attended = self.cross_attention(stacked)

-        # 6. Fusion
        # Fusion
        fused = self.fusion(attended)

-        # 7. Multi-Task Head
-        outputs = self.head(fused)
        return fused
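Beyond pretraining, `forward_backbone` doubles as a feature extractor. A sketch of that use, assuming a constructed `model` and a one-sample batch with the tabular shapes documented in `forward` below ("CCO" is just a placeholder SMILES):

```python
import torch

model.eval()
with torch.no_grad():
    fused = model.forward_backbone(
        ["CCO"],
        {
            "comp": torch.zeros(1, 5),
            "phys": torch.zeros(1, 12),
            "help": torch.zeros(1, 4),
            "exp": torch.zeros(1, 32),
        },
    )
print(fused.shape)  # [1, fusion_dim]
```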

    def forward_delivery(
        self,
        smiles: List[str],
        tabular: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Predict delivery only (used for pretraining).

        Args:
            smiles: list of SMILES strings, of length B
            tabular: Dict[str, Tensor]

        Returns:
            delivery: [B, 1] predicted delivery values
        """
        fused = self.forward_backbone(smiles, tabular)
        return self.head.delivery_head(fused)

    def forward(
        self,
        smiles: List[str],
        tabular: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Full multi-task forward pass.

        Args:
            smiles: list of SMILES strings, of length B
            tabular: Dict[str, Tensor] containing:
                - "comp": [B, 5] formulation ratios
                - "phys": [B, 12] physical parameters
                - "help": [B, 4] helper lipid
                - "exp": [B, 32] experimental conditions

        Returns:
            Dict[str, Tensor]:
                - "size": [B, 1]
                - "pdi": [B, 4]
                - "ee": [B, 3]
                - "delivery": [B, 1]
                - "biodist": [B, 7]
                - "toxic": [B, 2]
        """
        fused = self.forward_backbone(smiles, tabular)
        outputs = self.head(fused)
        return outputs

    def clear_cache(self) -> None:
@@ -195,6 +247,69 @@ class LNPModel(nn.Module):
        if self.mpnn_encoder is not None:
            self.mpnn_encoder.clear_cache()

    def get_backbone_state_dict(self) -> Dict[str, torch.Tensor]:
        """
        Return the state_dict of the backbone (without the task heads).

        Covers: token_projector, cross_attention, fusion
        """
        backbone_keys = []
        for name in self.state_dict().keys():
            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
                backbone_keys.append(name)

        return {k: v for k, v in self.state_dict().items() if k in backbone_keys}

    def get_delivery_head_state_dict(self) -> Dict[str, torch.Tensor]:
        """Return the state_dict of the delivery head."""
        return {
            k: v for k, v in self.state_dict().items()
            if k.startswith("head.delivery_head.")
        }

    def load_pretrain_weights(
        self,
        pretrain_state_dict: Dict[str, torch.Tensor],
        load_delivery_head: bool = True,
        strict: bool = False,
    ) -> None:
        """
        Load the backbone and (optionally) delivery head weights from a pretrain checkpoint.

        Args:
            pretrain_state_dict: state_dict of the pretrained model
            load_delivery_head: whether to load the delivery head weights
            strict: whether to require an exact match (default False: missing/extra keys are allowed)
        """
        # Select the parameters to load
        keys_to_load = []
        for name in pretrain_state_dict.keys():
            # Backbone
            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
                keys_to_load.append(name)
            # Delivery head (optional)
            elif load_delivery_head and name.startswith("head.delivery_head."):
                keys_to_load.append(name)

        filtered_state_dict = {k: v for k, v in pretrain_state_dict.items() if k in keys_to_load}

        # Load the weights
        unexpected = []
        model_state = self.state_dict()
        for k, v in filtered_state_dict.items():
            if k in model_state:
                if model_state[k].shape == v.shape:
                    model_state[k] = v
                else:
                    unexpected.append(f"{k} (shape mismatch: {model_state[k].shape} vs {v.shape})")
            else:
                unexpected.append(k)

        # Backbone keys this model expects that the checkpoint did not provide
        # (originally `missing` was initialized but never populated, so the strict
        # check below could never fire on missing keys)
        missing = [
            k for k in model_state
            if k.startswith(("token_projector.", "cross_attention.", "fusion."))
            and k not in filtered_state_dict
        ]

        self.load_state_dict(model_state, strict=False)

        if strict and (missing or unexpected):
            raise RuntimeError(f"Missing keys: {missing}, Unexpected keys: {unexpected}")
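The finetuning entry point (`lnp_ml/modeling/train.py`, further down this diff) calls this roughly as follows:

```python
import torch

checkpoint = torch.load("models/pretrain_delivery.pt", map_location="cpu")
model.load_pretrain_weights(
    pretrain_state_dict=checkpoint["model_state_dict"],
    load_delivery_head=True,   # set False to reinitialize the delivery head
    strict=False,
)
```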


class LNPModelWithoutMPNN(LNPModel):
    """Simplified variant that does not use MPNN."""
lnp_ml/modeling/predict.py

@@ -1,7 +1,7 @@
"""Prediction script: run inference with a trained model."""

from pathlib import Path
-from typing import Dict, List
from typing import Dict, List, Union

import pandas as pd
import torch

@@ -11,17 +11,54 @@ import typer

from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
-from lnp_ml.modeling.models import LNPModelWithoutMPNN
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN


app = typer.Typer()

# Default location of the MPNN ensemble
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"

-def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:
-    """Load a trained model"""

def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Auto-discover the model.pt files of the MPNN ensemble."""
    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not model_paths:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(p) for p in model_paths]


def load_model(
    model_path: Path,
    device: torch.device,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """
    Load a trained model.

    Picks the model class automatically from the checkpoint's config.use_mpnn.
    """
    checkpoint = torch.load(model_path, map_location=device)
    config = checkpoint["config"]
    use_mpnn = config.get("use_mpnn", False)

    if use_mpnn:
        # Auto-discover the MPNN ensemble
        logger.info("Model was trained with MPNN, auto-detecting ensemble...")
        ensemble_paths = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths)} MPNN models")

        model = LNPModel(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            mpnn_ensemble_paths=ensemble_paths,
            mpnn_device=mpnn_device,
        )
    else:
        model = LNPModelWithoutMPNN(
            d_model=config["d_model"],
            num_heads=config["num_heads"],

@@ -30,6 +67,7 @@ def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
        )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

@@ -43,7 +81,7 @@ def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:

@torch.no_grad()
def predict_batch(
-    model: LNPModelWithoutMPNN,
    model: Union[LNPModel, LNPModelWithoutMPNN],
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, List]:
lnp_ml/modeling/pretrain.py (new file, 361 lines)
@@ -0,0 +1,361 @@
"""Pretraining script: pretrain the backbone + delivery head on external LiON data."""

import json
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
import typer

from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN

# Default location of the MPNN ensemble
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"


def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """
    Auto-discover the model.pt files of the MPNN ensemble.

    Searches base_dir for all cv_*/fold_*/model_*/model.pt files.
    """
    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not model_paths:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(p) for p in model_paths]


app = typer.Typer()


class EarlyStopping:
    """Early-stopping helper."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float("inf")

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience
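A quick illustration of the stopping behavior (the counter tracks consecutive epochs without an improvement of at least min_delta):

```python
stopper = EarlyStopping(patience=3)
for epoch, val_loss in enumerate([0.9, 0.8, 0.81, 0.82, 0.83]):
    if stopper(val_loss):  # True once 3 epochs in a row fail to improve
        print(f"stop at epoch {epoch}")  # -> stop at epoch 4
        break
```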


def warmup_cache(model: nn.Module, smiles_list: List[str], batch_size: int = 256) -> None:
    """
    Warm up the RDKit feature cache so training is not blocked by featurization.
    """
    unique_smiles = list(set(smiles_list))
    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")

    for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
        batch = unique_smiles[i:i + batch_size]
        model.rdkit_encoder(batch)

    logger.success(f"Cache warmup complete. Cached {len(model.rdkit_encoder._cache)} SMILES.")


def train_epoch_delivery(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
) -> Dict[str, float]:
    """
    One pretraining epoch (delivery task only).
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch in pbar:
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        optimizer.zero_grad()

        # Forward: predict delivery only
        pred = model.forward_delivery(smiles, tabular)  # [B, 1]
        pred = pred.squeeze(-1)  # [B]

        # Compute the loss (over valid samples only)
        if mask.any():
            loss = nn.functional.mse_loss(pred[mask], targets[mask])
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * mask.sum().item()
            n_samples += mask.sum().item()

        pbar.set_postfix({"loss": total_loss / max(n_samples, 1)})

    avg_loss = total_loss / max(n_samples, 1)
    return {"loss": avg_loss, "n_samples": n_samples}


@torch.no_grad()
def validate_delivery(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """
    Validation (delivery task only).
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    for batch in loader:
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        pred = model.forward_delivery(smiles, tabular).squeeze(-1)

        if mask.any():
            loss = nn.functional.mse_loss(pred[mask], targets[mask])
            total_loss += loss.item() * mask.sum().item()
            n_samples += mask.sum().item()

    avg_loss = total_loss / max(n_samples, 1)
    return {"loss": avg_loss, "n_samples": n_samples}


def pretrain(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
) -> dict:
    """
    Pretraining loop.

    Returns:
        the training history and the best validation loss
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, verbose=True
    )
    early_stopping = EarlyStopping(patience=patience)

    history = {"train": [], "val": []}
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        # Train
        train_metrics = train_epoch_delivery(model, train_loader, optimizer, device, epoch)

        # Validate
        val_metrics = validate_delivery(model, val_loader, device)

        # Log
        logger.info(
            f"Epoch {epoch + 1}/{epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f}"
        )

        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        # Learning rate scheduling
        scheduler.step(val_metrics["loss"])

        # Save best model
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            logger.info(f"  -> New best model (val_loss={best_val_loss:.4f})")

        # Early stopping
        if early_stopping(val_metrics["loss"]):
            logger.info(f"Early stopping at epoch {epoch + 1}")
            break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "history": history,
        "best_val_loss": best_val_loss,
    }


@app.command()
def main(
    # Data paths (already-processed parquet files)
    train_path: Path = PROCESSED_DATA_DIR / "train_pretrain.parquet",
    val_path: Path = PROCESSED_DATA_DIR / "val_pretrain.parquet",
    output_dir: Path = MODELS_DIR,
    # Model parameters
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # MPNN parameters (optional)
    use_mpnn: bool = False,  # enable MPNN and auto-load the ensemble from the default path
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,  # comma-separated list of paths
    mpnn_device: str = "cpu",
    # Training parameters
    batch_size: int = 64,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Pretrain the backbone + delivery head on external LiON data.

    Run `make data_pretrain` first to generate the parquet files.

    Use --use-mpnn to enable the MPNN encoder (auto-loaded from models/mpnn/all_amine_split_for_LiON).

    Outputs:
        - models/pretrain_delivery.pt: backbone + delivery head weights
        - models/pretrain_history.json: training history
    """
    logger.info(f"Using device: {device}")
    device_obj = torch.device(device)

    # Load the already-processed parquet files
    logger.info(f"Loading train data from {train_path}")
    train_df = pd.read_parquet(train_path)
    train_dataset = ExternalDeliveryDataset(train_df)

    logger.info(f"Loading val data from {val_path}")
    val_df = pd.read_parquet(val_path)
    val_dataset = ExternalDeliveryDataset(val_df)

    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    # Resolve the MPNN configuration
    # Priority: mpnn_checkpoint > mpnn_ensemble_paths > use_mpnn (auto-discovery)
    ensemble_paths_list = None
    if mpnn_ensemble_paths:
        ensemble_paths_list = mpnn_ensemble_paths.split(",")
    elif use_mpnn and mpnn_checkpoint is None:
        # --use-mpnn without explicit paths: auto-discover
        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
        ensemble_paths_list = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths_list)} MPNN models")

    enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None

    # Build the model
    logger.info(f"Creating model (use_mpnn={enable_mpnn})...")
    if enable_mpnn:
        model = LNPModel(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            mpnn_checkpoint=mpnn_checkpoint,
            mpnn_ensemble_paths=ensemble_paths_list,
            mpnn_device=mpnn_device,
        )
    else:
        model = LNPModelWithoutMPNN(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
        )

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params:,}")

    # Warm up the RDKit cache (avoids stalls during training)
    all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()
    warmup_cache(model, all_smiles, batch_size=256)

    # Pretrain
    logger.info("Starting pretraining on external data (delivery only)...")
    result = pretrain(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        device=device_obj,
        lr=lr,
        weight_decay=weight_decay,
        epochs=epochs,
        patience=patience,
    )

    # Save the pretrain checkpoint
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / "pretrain_delivery.pt"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "backbone_state_dict": model.get_backbone_state_dict(),
            "delivery_head_state_dict": model.get_delivery_head_state_dict(),
            "config": {
                "d_model": d_model,
                "num_heads": num_heads,
                "n_attn_layers": n_attn_layers,
                "fusion_strategy": fusion_strategy,
                "head_hidden_dim": head_hidden_dim,
                "dropout": dropout,
                "use_mpnn": enable_mpnn,
            },
            "best_val_loss": result["best_val_loss"],
        },
        checkpoint_path,
    )
    logger.success(f"Saved pretrain checkpoint to {checkpoint_path}")

    # Save the training history
    history_path = output_dir / "pretrain_history.json"
    with open(history_path, "w") as f:
        json.dump(result["history"], f, indent=2)
    logger.success(f"Saved pretrain history to {history_path}")

    logger.success(
        f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
    )


if __name__ == "__main__":
    app()
lnp_ml/modeling/train.py

@@ -2,7 +2,7 @@

import json
from pathlib import Path
-from typing import Optional
from typing import List, Optional, Union

import pandas as pd
import torch

@@ -12,7 +12,7 @@ import typer

from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
-from lnp_ml.modeling.models import LNPModelWithoutMPNN
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer import (
    train_epoch,
    validate,

@@ -20,6 +20,21 @@ from lnp_ml.modeling.trainer import (
    LossWeights,
)

# Default location of the MPNN ensemble
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"


def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """
    Auto-discover the model.pt files of the MPNN ensemble.

    Searches base_dir for all cv_*/fold_*/model_*/model.pt files.
    """
    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not model_paths:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(p) for p in model_paths]


app = typer.Typer()

@@ -31,8 +46,27 @@ def create_model(
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
-) -> LNPModelWithoutMPNN:
-    """Build the model"""
    # MPNN parameters (optional)
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[List[str]] = None,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """Build the model (with optional MPNN encoder)."""
    use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None

    if use_mpnn:
        return LNPModel(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            mpnn_checkpoint=mpnn_checkpoint,
            mpnn_ensemble_paths=mpnn_ensemble_paths,
            mpnn_device=mpnn_device,
        )
    else:
        return LNPModelWithoutMPNN(
            d_model=d_model,
            num_heads=num_heads,

@@ -191,6 +225,11 @@ def main(
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # MPNN parameters (optional)
    use_mpnn: bool = False,  # enable MPNN and auto-load the ensemble from the default path
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,  # comma-separated list of paths
    mpnn_device: str = "cpu",
    # Training parameters
    batch_size: int = 32,
    lr: float = 1e-4,

@@ -201,13 +240,20 @@ def main(
    tune: bool = False,
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    # Pretrained-weight loading
    init_from_pretrain: Optional[Path] = None,
    load_delivery_head: bool = True,
    freeze_backbone: bool = False,  # freeze the backbone and train only the heads
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
-    Train the LNP prediction model.
    Train the LNP prediction model (multi-task finetune).

    Use --tune to enable hyperparameter tuning.
    Use --init-from-pretrain to initialize the backbone from a pretrain checkpoint.
    Use --use-mpnn to enable the MPNN encoder (auto-loaded from models/mpnn/all_amine_split_for_LiON).
    Use --freeze-backbone to freeze the backbone and train only the multi-task heads.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)

@@ -258,8 +304,21 @@ def main(
        lr = best_params["lr"]
        weight_decay = best_params["weight_decay"]

    # Resolve the MPNN configuration
    # Priority: mpnn_checkpoint > mpnn_ensemble_paths > use_mpnn (auto-discovery)
    ensemble_paths_list = None
    if mpnn_ensemble_paths:
        ensemble_paths_list = mpnn_ensemble_paths.split(",")
    elif use_mpnn and mpnn_checkpoint is None:
        # --use-mpnn without explicit paths: auto-discover
        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
        ensemble_paths_list = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths_list)} MPNN models")

    enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None

    # Build the model
-    logger.info("Creating model...")
    logger.info(f"Creating model (use_mpnn={enable_mpnn})...")
    model = create_model(
        d_model=d_model,
        num_heads=num_heads,

@@ -267,11 +326,48 @@ def main(
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
        mpnn_checkpoint=mpnn_checkpoint,
        mpnn_ensemble_paths=ensemble_paths_list,
        mpnn_device=mpnn_device,
    )

    # Load pretrained weights (if requested)
    if init_from_pretrain is not None:
        logger.info(f"Loading pretrain weights from {init_from_pretrain}")
        checkpoint = torch.load(init_from_pretrain, map_location="cpu")

        # Check that the configurations are compatible
        pretrain_config = checkpoint.get("config", {})
        if pretrain_config.get("d_model") != d_model:
            logger.warning(
                f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
                f"current={d_model}. Skipping pretrain loading."
            )
        else:
            # Load the backbone + (optionally) the delivery head
            model.load_pretrain_weights(
                pretrain_state_dict=checkpoint["model_state_dict"],
                load_delivery_head=load_delivery_head,
                strict=False,
            )
            logger.success(
                f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})"
            )

    # Freeze the backbone (if requested)
    if freeze_backbone:
        logger.info("Freezing backbone (token_projector, cross_attention, fusion)...")
        frozen_count = 0
        for name, param in model.named_parameters():
            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
                param.requires_grad = False
                frozen_count += 1
        logger.info(f"Frozen {frozen_count} parameter tensors")

    # Print model info
-    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"Model parameters: {n_params:,}")
    n_params_total = sum(p.numel() for p in model.parameters())
    n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")

    # Train
    logger.info("Starting training...")

@@ -297,8 +393,10 @@ def main(
            "fusion_strategy": fusion_strategy,
            "head_hidden_dim": head_hidden_dim,
            "dropout": dropout,
            "use_mpnn": enable_mpnn,
        },
        "best_val_loss": result["best_val_loss"],
        "init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
    }, model_path)
    logger.success(f"Saved model to {model_path}")
models/history.json (2446 lines changed; diff suppressed because it is too large)
models/model.pt (binary file not shown)
models/pretrain_delivery.pt (new binary file, not shown)
models/pretrain_history.json (new file, 286 lines)
@@ -0,0 +1,286 @@
{
  "train": [
    {"loss": 0.8151603855680482, "n_samples": 8721},
    {"loss": 0.6867990841792819, "n_samples": 8721},
    {"loss": 0.645540308540888, "n_samples": 8721},
    {"loss": 0.5923176541020599, "n_samples": 8721},
    {"loss": 0.5720762926262872, "n_samples": 8721},
    {"loss": 0.5477570670417328, "n_samples": 8721},
    {"loss": 0.5280393017717573, "n_samples": 8721},
    {"loss": 0.5122504676513313, "n_samples": 8721},
    {"loss": 0.49667307051028314, "n_samples": 8721},
    {"loss": 0.486139440352648, "n_samples": 8721},
    {"loss": 0.4749755339466122, "n_samples": 8721},
    {"loss": 0.4636757543530298, "n_samples": 8721},
    {"loss": 0.4543497681877452, "n_samples": 8721},
    {"loss": 0.4408158337956461, "n_samples": 8721},
    {"loss": 0.4419790126221837, "n_samples": 8721},
    {"loss": 0.42850686623585116, "n_samples": 8721},
    {"loss": 0.41607048387867007, "n_samples": 8721},
    {"loss": 0.427172136486513, "n_samples": 8721},
    {"loss": 0.4125568530569382, "n_samples": 8721},
    {"loss": 0.39480836287767923, "n_samples": 8721},
    {"loss": 0.3885056775666858, "n_samples": 8721},
    {"loss": 0.3894976457588827, "n_samples": 8721},
    {"loss": 0.3890058272899995, "n_samples": 8721},
    {"loss": 0.3741690826284791, "n_samples": 8721},
    {"loss": 0.3534914434345719, "n_samples": 8721},
    {"loss": 0.3349389765134386, "n_samples": 8721},
    {"loss": 0.32965143874976194, "n_samples": 8721},
    {"loss": 0.32094062546116675, "n_samples": 8721},
    {"loss": 0.32526135008251184, "n_samples": 8721},
    {"loss": 0.31289531808423826, "n_samples": 8721},
    {"loss": 0.3088379208288558, "n_samples": 8721},
    {"loss": 0.2994744991261045, "n_samples": 8721},
    {"loss": 0.2981521815160671, "n_samples": 8721},
    {"loss": 0.29143649446979303, "n_samples": 8721},
    {"loss": 0.29075756723379653, "n_samples": 8721}
  ],
  "val": [
    {"loss": 0.7625711447683281, "n_samples": 969},
    {"loss": 0.7092331695236781, "n_samples": 969},
    {"loss": 0.7014068689723995, "n_samples": 969},
    {"loss": 0.6595172673863646, "n_samples": 969},
    {"loss": 0.6312279044905191, "n_samples": 969},
    {"loss": 0.6349272860831151, "n_samples": 969},
    {"loss": 0.6587623598744133, "n_samples": 969},
    {"loss": 0.6093261837651732, "n_samples": 969},
    {"loss": 0.6125607111474924, "n_samples": 969},
    {"loss": 0.6005943137518024, "n_samples": 969},
    {"loss": 0.6876292386783289, "n_samples": 969},
    {"loss": 0.5940848466228036, "n_samples": 969},
    {"loss": 0.5820883587079644, "n_samples": 969},
    {"loss": 0.6302792748938035, "n_samples": 969},
    {"loss": 0.5849901610914275, "n_samples": 969},
    {"loss": 0.5830434826428553, "n_samples": 969},
    {"loss": 0.5643168952858116, "n_samples": 969},
    {"loss": 0.5592790719340829, "n_samples": 969},
    {"loss": 0.600335100686833, "n_samples": 969},
    {"loss": 0.5646457721097674, "n_samples": 969},
    {"loss": 0.6288956836004376, "n_samples": 969},
    {"loss": 0.5771863222183704, "n_samples": 969},
    {"loss": 0.5738056593250687, "n_samples": 969},
    {"loss": 0.5636531712085593, "n_samples": 969},
    {"loss": 0.5465074849879163, "n_samples": 969},
    {"loss": 0.5701294839843508, "n_samples": 969},
    {"loss": 0.5570075802420438, "n_samples": 969},
    {"loss": 0.5711473401701241, "n_samples": 969},
    {"loss": 0.5576858864741552, "n_samples": 969},
    {"loss": 0.5624132422716871, "n_samples": 969},
    {"loss": 0.5655298555506272, "n_samples": 969},
    {"loss": 0.5568078993151677, "n_samples": 969},
    {"loss": 0.567752199383958, "n_samples": 969},
    {"loss": 0.5683093779442603, "n_samples": 969},
    {"loss": 0.5741443767974497, "n_samples": 969}
  ]
}
scripts/process_external.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""External data preprocessing script: external -> processed"""

from pathlib import Path

import pandas as pd
import typer
from loguru import logger
from sklearn.model_selection import train_test_split

from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
    COMP_COLS,
    HELP_COLS,
    LNPDatasetConfig,
    get_exp_cols,
    get_phys_cols,
    process_external_dataframe,
)


app = typer.Typer()


@app.command()
def main(
    input_path: Path = EXTERNAL_DATA_DIR / "all_data_LiON.csv",
    output_dir: Path = PROCESSED_DATA_DIR,
    train_ratio: float = 0.9,
    seed: int = 42,
):
    """
    Process the external LiON data and generate the parquet files used for pretraining.

    Outputs:
        - processed/train_pretrain.parquet
        - processed/val_pretrain.parquet
        - processed/feature_columns_pretrain.txt
    """
    logger.info(f"Loading external data from {input_path}")
    df = pd.read_csv(input_path)
    logger.info(f"Loaded {len(df)} samples")

    # Drop rows where quantified_delivery is missing
    if "quantified_delivery" in df.columns:
        before_len = len(df)
        df = df[df["quantified_delivery"].notna()].reset_index(drop=True)
        logger.info(f"Filtered NaN delivery: {before_len} -> {len(df)} samples")

    # Process the data (column alignment, one-hot generation)
    logger.info("Processing dataframe (column alignment, one-hot encoding)...")
    df = process_external_dataframe(df)

    # Collect the required columns
    config = LNPDatasetConfig()
    feature_cols = (
        ["smiles"]
        + config.comp_cols
        + config.phys_cols
        + config.help_cols
        + config.exp_cols
        + ["quantified_delivery"]
    )

    # Keep only the required columns
    available_cols = [c for c in feature_cols if c in df.columns]
    missing_cols = [c for c in feature_cols if c not in df.columns]
    if missing_cols:
        logger.warning(f"Missing columns (will be filled with 0): {missing_cols}")
        for col in missing_cols:
            df[col] = 0.0

    df = df[feature_cols]

    # Train/val split
    logger.info(f"Splitting data: train_ratio={train_ratio}, seed={seed}")
    train_df, val_df = train_test_split(
        df, train_size=train_ratio, random_state=seed, shuffle=True
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    logger.info(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")

    # Save
    output_dir.mkdir(parents=True, exist_ok=True)

    train_path = output_dir / "train_pretrain.parquet"
    val_path = output_dir / "val_pretrain.parquet"

    train_df.to_parquet(train_path, index=False)
    val_df.to_parquet(val_path, index=False)

    logger.success(f"Saved train data to {train_path}")
    logger.success(f"Saved val data to {val_path}")

    # Save the feature-column configuration
    cols_path = output_dir / "feature_columns_pretrain.txt"
    with open(cols_path, "w") as f:
        f.write("\n".join(feature_cols))
    logger.success(f"Saved feature columns to {cols_path}")


if __name__ == "__main__":
    app()