mirror of https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 11:53:13 +08:00
...
This commit is contained in:
parent 96a27caab2
commit 0290649df1

Makefile | 29
@@ -68,15 +68,38 @@ clean_data: requirements
 data: requirements
 	$(PYTHON_INTERPRETER) scripts/process_data.py

-## Train model
+## Process external data for pretrain (external -> processed)
+.PHONY: data_pretrain
+data_pretrain: requirements
+	$(PYTHON_INTERPRETER) scripts/process_external.py
+
+# MPNN support: set USE_MPNN=1 to enable the MPNN encoder
+# e.g.: make pretrain USE_MPNN=1
+MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
+
+# Backbone freezing: set FREEZE_BACKBONE=1 to freeze the backbone and train only the heads
+# e.g.: make finetune FREEZE_BACKBONE=1
+FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)
+
+## Pretrain on external data (delivery only)
+.PHONY: pretrain
+pretrain: requirements
+	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain $(MPNN_FLAG)
+
+## Train model (multi-task, from scratch)
 .PHONY: train
 train: requirements
-	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train
+	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train $(MPNN_FLAG)
+
+## Finetune from pretrained checkpoint (use FREEZE_BACKBONE=1 to freeze backbone)
+.PHONY: finetune
+finetune: requirements
+	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG)

 ## Train with hyperparameter tuning
 .PHONY: tune
 tune: requirements
-	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune
+	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG)

 ## Run predictions
 .PHONY: predict
README.md | 101

@@ -4,7 +4,106 @@
 <img src="https://img.shields.io/badge/CCDS-Project%20template-328F97?logo=cookiecutter" />
 </a>

-A short description of the project.
+A model for predicting the drug-delivery performance of LNPs (lipid nanoparticles).
+
+## Quick start
+
+### 1. Set up the environment
+
+```bash
+pixi install
+pixi shell
+```
+
+### 2. Data processing
+
+```bash
+# Clean the raw data (raw -> interim)
+make clean_data
+
+# Process the internal dataset (interim -> processed)
+make data
+
+# Process the external pretraining data (external -> processed)
+make data_pretrain
+```
+
+### 3. Train a model
+
+**Option A: train directly (from scratch)**
+
+```bash
+make train
+```
+
+**Option B: pretrain + finetune (recommended)**
+
+Pretrain on the external LiON dataset (~9,000 rows), then finetune on the internal data:
+
+```bash
+# Step 1: process the external data
+make data_pretrain
+
+# Step 2: pretrain the backbone + delivery head on the external data
+make pretrain
+
+# Step 3: load the pretrained weights and run multi-task finetuning on the internal data
+make finetune
+```
+
+**Option C: hyperparameter tuning**
+
+```bash
+make tune
+```
+
+### 4. Testing and prediction
+
+```bash
+# Evaluate on the test set
+make test
+
+# Generate predictions
+make predict
+```
+
+## Training pipeline in detail
+
+### Pretraining
+
+On the external LiON data, only the `quantified_delivery` task is trained:
+
+```bash
+# 1. Process the external data first
+python scripts/process_external.py
+
+# 2. Pretrain
+python -m lnp_ml.modeling.pretrain \
+    --train-path data/processed/train_pretrain.parquet \
+    --val-path data/processed/val_pretrain.parquet \
+    --epochs 50 \
+    --lr 1e-4
+```
+
+Outputs:
+
+- `data/processed/train_pretrain.parquet`: processed training data
+- `data/processed/val_pretrain.parquet`: processed validation data
+- `models/pretrain_delivery.pt`: backbone + delivery head weights
+- `models/pretrain_history.json`: training history
+
+### Finetuning
+
+Load the pretrained weights and train on the internal multi-task data:
+
+```bash
+python -m lnp_ml.modeling.train \
+    --init-from-pretrain models/pretrain_delivery.pt \
+    --load-delivery-head  # optional: whether to load the delivery head weights
+```
+
+Outputs:
+
+- `models/model.pt`: full model weights
+- `models/history.json`: training history

 ## Project Organization
lnp_ml/dataset.py

@@ -346,3 +346,185 @@ def load_dataset(
         LNPDataset(val_df, config),
         LNPDataset(test_df, config),
     )
+
+
+# ============ External data (delivery only) processing ============
+
+# Value map for Value_name in the external data (spaces -> underscores)
+EXTERNAL_VALUE_NAME_MAP = {
+    "log luminescence": "log_luminescence",
+    "Luminescence": "luminescence",
+    "FFL silencing": "FFL_silencing",
+    "Peptide abundance": "Peptide_abundance",
+    "hEPO": "hEPO",
+    "FVII silencing": "FVII_silencing",
+    "GFP delivery": "GFP_delivery",
+    "Discretized luminescence": "Discretized_luminescence",
+}
+
+# Value map for Mix_type in the external data
+EXTERNAL_MIX_TYPE_MAP = {
+    "Hand": "Pipetting",  # the external data uses "Hand" to mean "Pipetting"
+    "Microfluidic": "Microfluidic",
+    "Pipetting": "Pipetting",
+}
+
+
+def process_external_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Process a DataFrame of external LiON data, aligning it to the column format the model expects.
+
+    Similar to process_dataframe, but adapted to the column differences in the external data:
+    - spaces in Value_name values must be converted to underscores
+    - "Hand" in Mix_type must be mapped to "Pipetting"
+    """
+    df = df.copy()
+
+    # 1. Preprocessing: map the Value_name and Mix_type values
+    if "Value_name" in df.columns:
+        df["Value_name"] = df["Value_name"].map(
+            lambda x: EXTERNAL_VALUE_NAME_MAP.get(x, x) if pd.notna(x) else x
+        )
+
+    if "Mix_type" in df.columns:
+        df["Mix_type"] = df["Mix_type"].map(
+            lambda x: EXTERNAL_MIX_TYPE_MAP.get(x, x) if pd.notna(x) else x
+        )
+
+    # 2. Handle the one-hot columns for the phys token (generate from the raw column if absent)
+    for col, values in PHYS_ONEHOT_SPECS.items():
+        for v in values:
+            onehot_col = f"{col}_{v}"
+            if onehot_col not in df.columns:
+                if col in df.columns:
+                    df[onehot_col] = (df[col] == v).astype(float)
+                else:
+                    df[onehot_col] = 0.0
+
+    # 3. Handle the one-hot columns for the exp token
+    # Some columns already exist in the external data (e.g. Model_type_*, Delivery_target_*),
+    # but individual categories may be missing
+    for col, values in EXP_ONEHOT_SPECS.items():
+        for v in values:
+            onehot_col = f"{col}_{v}"
+            if onehot_col not in df.columns:
+                if col in df.columns:
+                    df[onehot_col] = (df[col] == v).astype(float)
+                else:
+                    df[onehot_col] = 0.0
+            else:
+                # ensure float dtype
+                df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
+
+    # 4. Ensure the comp columns exist and are float
+    for col in COMP_COLS:
+        if col in df.columns:
+            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
+        else:
+            df[col] = 0.0
+
+    # 5. Ensure the help columns exist
+    for col in HELP_COLS:
+        if col not in df.columns:
+            df[col] = 0.0
+        else:
+            df[col] = df[col].fillna(0.0).astype(float)
+
+    # 6. Handle quantified_delivery
+    if "quantified_delivery" in df.columns:
+        df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
+
+    return df
+
+
+class ExternalDeliveryDataset(Dataset):
+    """
+    External LiON dataset, used only for delivery pretraining.
+
+    Returns:
+    - smiles: str
+    - tabular: Dict[str, Tensor] with keys "comp", "phys", "help", "exp"
+    - targets: Dict[str, Tensor] with key "delivery"
+    - mask: Dict[str, Tensor] with key "delivery"
+    """
+
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        config: Optional[LNPDatasetConfig] = None,
+    ):
+        self.config = config or LNPDatasetConfig()
+        self.df = process_external_dataframe(df)
+
+        # Extract the data
+        self.smiles = self.df[SMILES_COL].tolist()
+
+        # Tabular features
+        self.comp = self.df[self.config.comp_cols].values.astype(np.float32)
+        self.phys = self.df[self.config.phys_cols].values.astype(np.float32)
+        self.help = self.df[self.config.help_cols].values.astype(np.float32)
+        self.exp = self.df[self.config.exp_cols].values.astype(np.float32)
+
+        # delivery is the only target
+        self.delivery = (
+            self.df["quantified_delivery"].values.astype(np.float32)
+            if "quantified_delivery" in self.df.columns
+            else None
+        )
+
+    def __len__(self) -> int:
+        return len(self.smiles)
+
+    def __getitem__(self, idx: int) -> Dict:
+        item = {
+            "smiles": self.smiles[idx],
+            "tabular": {
+                "comp": torch.from_numpy(self.comp[idx]),
+                "phys": torch.from_numpy(self.phys[idx]),
+                "help": torch.from_numpy(self.help[idx]),
+                "exp": torch.from_numpy(self.exp[idx]),
+            },
+            "targets": {},
+            "mask": {},
+        }
+
+        # delivery target and mask
+        if self.delivery is not None:
+            item["targets"]["delivery"] = torch.tensor(self.delivery[idx], dtype=torch.float32)
+            item["mask"]["delivery"] = torch.tensor(not np.isnan(self.delivery[idx]), dtype=torch.bool)
+
+        return item
+
+
+def load_external_dataset(
+    path: Path,
+    train_ratio: float = 0.9,
+    seed: int = 42,
+) -> Tuple[ExternalDeliveryDataset, ExternalDeliveryDataset]:
+    """
+    Load the external LiON dataset and split it into train/val.
+
+    Args:
+        path: path to the CSV file
+        train_ratio: fraction used for training (the remainder is validation)
+        seed: random seed
+
+    Returns:
+        (train_dataset, val_dataset)
+    """
+    df = pd.read_csv(path)
+
+    # Drop rows where quantified_delivery is missing
+    if "quantified_delivery" in df.columns:
+        df = df[df["quantified_delivery"].notna()].reset_index(drop=True)
+
+    # Shuffle
+    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
+
+    n = len(df)
+    n_train = int(n * train_ratio)
+
+    train_df = df.iloc[:n_train]
+    val_df = df.iloc[n_train:]
+
+    config = LNPDatasetConfig()
+
+    return (
+        ExternalDeliveryDataset(train_df, config),
+        ExternalDeliveryDataset(val_df, config),
+    )
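To make the alignment logic concrete, here is a minimal, self-contained sketch of the map-then-one-hot pattern used by `process_external_dataframe`. The toy spec below is invented for illustration; the real specs are `PHYS_ONEHOT_SPECS` and `EXP_ONEHOT_SPECS`:

```python
import pandas as pd

# Hypothetical spec in the same shape as PHYS_ONEHOT_SPECS / EXP_ONEHOT_SPECS:
# raw categorical column -> allowed categories.
ONEHOT_SPECS = {"Mix_type": ["Microfluidic", "Pipetting"]}

df = pd.DataFrame({"Mix_type": ["Hand", "Microfluidic", None]})

# Value mapping first ("Hand" is the external spelling of "Pipetting").
df["Mix_type"] = df["Mix_type"].map(
    lambda x: {"Hand": "Pipetting"}.get(x, x) if pd.notna(x) else x
)

# Then generate any missing one-hot columns, exactly as in process_external_dataframe.
for col, values in ONEHOT_SPECS.items():
    for v in values:
        onehot_col = f"{col}_{v}"
        if onehot_col not in df.columns:
            df[onehot_col] = (df[col] == v).astype(float)

print(df)
# Mix_type_Microfluidic = [0, 1, 0], Mix_type_Pipetting = [1, 0, 0];
# the NaN row gets 0.0 in every one-hot column.
```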
@@ -43,7 +43,13 @@ class RDKitFeaturizer:
         return np.array(MACCSkeys.GenMACCSKeys(mol).ToList(), dtype=np.float32)

     def _encode_desc(self, mol: Chem.Mol) -> np.ndarray:
-        return np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float32)
+        # Compute in float64, then clip to a float32-safe range to avoid overflow
+        desc_values = list(Descriptors.CalcMolDescriptors(mol).values())
+        arr = np.array(desc_values, dtype=np.float64)
+        # Replace inf/nan, clip to a float32-safe range
+        arr = np.nan_to_num(arr, nan=0.0, posinf=1e10, neginf=-1e10)
+        arr = np.clip(arr, -1e10, 1e10)
+        return arr.astype(np.float32)

     def _encode_one(self, smiles: str) -> Dict[str, np.ndarray]:
         mol = Chem.MolFromSmiles(smiles)
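For context on this fix: a few RDKit descriptors can exceed the float32 range on large molecules, so a direct float32 cast silently produces inf. A minimal sketch of the failure mode and of the sanitization now applied (toy values, not real descriptor output):

```python
import numpy as np

raw = np.array([1.5e39, np.nan, -np.inf, 42.0], dtype=np.float64)

# Direct cast: values beyond the float32 range (~3.4e38) become inf,
# and nan/inf pass straight through into the feature vector.
print(raw.astype(np.float32))  # [inf nan -inf 42.]

# Sanitize in float64 first, as _encode_desc now does.
clean = np.nan_to_num(raw, nan=0.0, posinf=1e10, neginf=-1e10)
clean = np.clip(clean, -1e10, 1e10)
print(clean.astype(np.float32))  # [1.e+10 0.e+00 -1.e+10 4.2e+01]
```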
lnp_ml/modeling/models.py

@@ -121,28 +121,16 @@ class LNPModel(nn.Module):
             dropout=dropout,
         )

-    def forward(
+    def _encode_and_project(
         self,
         smiles: List[str],
         tabular: Dict[str, torch.Tensor],
-    ) -> Dict[str, torch.Tensor]:
+    ) -> torch.Tensor:
         """
-        Args:
-            smiles: list of SMILES strings, length B
-            tabular: Dict[str, Tensor] containing:
-                - "comp": [B, 5] formulation ratios
-                - "phys": [B, 12] physical parameters
-                - "help": [B, 4] helper lipid
-                - "exp": [B, 32] experimental conditions
+        Internal method: encode SMILES and tabular features, return stacked tokens.

         Returns:
-            Dict[str, Tensor]:
-                - "size": [B, 1]
-                - "pdi": [B, 4]
-                - "ee": [B, 3]
-                - "delivery": [B, 1]
-                - "biodist": [B, 7]
-                - "toxic": [B, 2]
+            stacked: [B, n_tokens, d_model]
         """
         # 1. Encode SMILES
         rdkit_features = self.rdkit_encoder(smiles)  # {"morgan", "maccs", "desc"}
@@ -170,23 +158,87 @@
         projected = self.token_projector(all_features)  # Dict[str, [B, d_model]]

         # 4. Stack tokens: [B, n_tokens, d_model]
-        # Stack in order: Channel A (chemistry) + Channel B (formulation/experiment)
         if self.use_mpnn:
             token_order = ["mpnn", "morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
         else:
             token_order = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]

         stacked = torch.stack([projected[k] for k in token_order], dim=1)
+        return stacked

-        # 5. Cross Modal Attention
+    def forward_backbone(
+        self,
+        smiles: List[str],
+        tabular: Dict[str, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Backbone forward pass: encode -> project -> attention -> fusion, without the task heads.
+
+        Used during pretraining, or wherever features need to be extracted.
+
+        Args:
+            smiles: list of SMILES strings, length B
+            tabular: Dict[str, Tensor]
+
+        Returns:
+            fused: [B, fusion_dim] fused feature vector
+        """
+        # encode + project + stack
+        stacked = self._encode_and_project(smiles, tabular)
+
+        # Cross Modal Attention
         attended = self.cross_attention(stacked)

-        # 6. Fusion
+        # Fusion
         fused = self.fusion(attended)

+        return fused

-        # 7. Multi-Task Head
+    def forward_delivery(
+        self,
+        smiles: List[str],
+        tabular: Dict[str, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Predict delivery only (used for pretraining).
+
+        Args:
+            smiles: list of SMILES strings, length B
+            tabular: Dict[str, Tensor]
+
+        Returns:
+            delivery: [B, 1] predicted delivery values
+        """
+        fused = self.forward_backbone(smiles, tabular)
+        return self.head.delivery_head(fused)
+
+    def forward(
+        self,
+        smiles: List[str],
+        tabular: Dict[str, torch.Tensor],
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Full multi-task forward pass.
+
+        Args:
+            smiles: list of SMILES strings, length B
+            tabular: Dict[str, Tensor] containing:
+                - "comp": [B, 5] formulation ratios
+                - "phys": [B, 12] physical parameters
+                - "help": [B, 4] helper lipid
+                - "exp": [B, 32] experimental conditions
+
+        Returns:
+            Dict[str, Tensor]:
+                - "size": [B, 1]
+                - "pdi": [B, 4]
+                - "ee": [B, 3]
+                - "delivery": [B, 1]
+                - "biodist": [B, 7]
+                - "toxic": [B, 2]
+        """
+        fused = self.forward_backbone(smiles, tabular)
         outputs = self.head(fused)

         return outputs

     def clear_cache(self) -> None:
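A short usage sketch of the refactored entry points, assuming `model` is an LNPModel instance built as elsewhere in this commit (the SMILES string and tabular values are placeholders; shapes follow the docstrings above):

```python
import torch

smiles = ["CCO"]  # placeholder SMILES, batch of 1
tabular = {
    "comp": torch.zeros(1, 5),   # formulation ratios
    "phys": torch.zeros(1, 12),  # physical parameters
    "help": torch.zeros(1, 4),   # helper lipid
    "exp":  torch.zeros(1, 32),  # experimental conditions
}

feats = model.forward_backbone(smiles, tabular)      # [1, fusion_dim] features, no heads
delivery = model.forward_delivery(smiles, tabular)   # [1, 1], the pretraining path
outputs = model(smiles, tabular)                     # dict with all six task heads
```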
@@ -195,6 +247,69 @@ class LNPModel(nn.Module):
         if self.mpnn_encoder is not None:
             self.mpnn_encoder.clear_cache()

+    def get_backbone_state_dict(self) -> Dict[str, torch.Tensor]:
+        """
+        Get the state_dict of the backbone (task heads excluded).
+
+        Covers: token_projector, cross_attention, fusion
+        """
+        backbone_keys = []
+        for name in self.state_dict().keys():
+            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
+                backbone_keys.append(name)
+
+        return {k: v for k, v in self.state_dict().items() if k in backbone_keys}
+
+    def get_delivery_head_state_dict(self) -> Dict[str, torch.Tensor]:
+        """Get the state_dict of the delivery head."""
+        return {
+            k: v for k, v in self.state_dict().items()
+            if k.startswith("head.delivery_head.")
+        }
+
+    def load_pretrain_weights(
+        self,
+        pretrain_state_dict: Dict[str, torch.Tensor],
+        load_delivery_head: bool = True,
+        strict: bool = False,
+    ) -> None:
+        """
+        Load backbone and (optionally) delivery head weights from a pretraining checkpoint.
+
+        Args:
+            pretrain_state_dict: state_dict of the pretrained model
+            load_delivery_head: whether to load the delivery head weights
+            strict: whether to match keys strictly (default False, allowing missing/extra keys)
+        """
+        # Select the parameters to load
+        keys_to_load = []
+        for name in pretrain_state_dict.keys():
+            # Backbone
+            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
+                keys_to_load.append(name)
+            # Delivery head (optional)
+            elif load_delivery_head and name.startswith("head.delivery_head."):
+                keys_to_load.append(name)
+
+        filtered_state_dict = {k: v for k, v in pretrain_state_dict.items() if k in keys_to_load}
+
+        # Load the weights
+        missing, unexpected = [], []
+        model_state = self.state_dict()
+        for k, v in filtered_state_dict.items():
+            if k in model_state:
+                if model_state[k].shape == v.shape:
+                    model_state[k] = v
+                else:
+                    unexpected.append(f"{k} (shape mismatch: {model_state[k].shape} vs {v.shape})")
+            else:
+                unexpected.append(k)
+
+        self.load_state_dict(model_state, strict=False)
+
+        if strict and (missing or unexpected):
+            raise RuntimeError(f"Missing keys: {missing}, Unexpected keys: {unexpected}")
+

 class LNPModelWithoutMPNN(LNPModel):
     """Simplified variant without MPNN."""
|
|||||||
"""预测脚本:使用训练好的模型进行推理"""
|
"""预测脚本:使用训练好的模型进行推理"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -11,25 +11,63 @@ import typer
|
|||||||
|
|
||||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
from lnp_ml.dataset import LNPDataset, collate_fn
|
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||||
from lnp_ml.modeling.models import LNPModelWithoutMPNN
|
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||||
|
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
# MPNN ensemble 默认路径
|
||||||
|
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||||
|
|
||||||
def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:
|
|
||||||
"""加载训练好的模型"""
|
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
|
||||||
|
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
|
||||||
|
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
|
||||||
|
if not model_paths:
|
||||||
|
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
|
||||||
|
return [str(p) for p in model_paths]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
model_path: Path,
|
||||||
|
device: torch.device,
|
||||||
|
mpnn_device: str = "cpu",
|
||||||
|
) -> Union[LNPModel, LNPModelWithoutMPNN]:
|
||||||
|
"""
|
||||||
|
加载训练好的模型。
|
||||||
|
|
||||||
|
自动根据 checkpoint 的 config.use_mpnn 选择模型类型。
|
||||||
|
"""
|
||||||
checkpoint = torch.load(model_path, map_location=device)
|
checkpoint = torch.load(model_path, map_location=device)
|
||||||
config = checkpoint["config"]
|
config = checkpoint["config"]
|
||||||
|
use_mpnn = config.get("use_mpnn", False)
|
||||||
|
|
||||||
|
if use_mpnn:
|
||||||
|
# 自动查找 MPNN ensemble
|
||||||
|
logger.info("Model was trained with MPNN, auto-detecting ensemble...")
|
||||||
|
ensemble_paths = find_mpnn_ensemble_paths()
|
||||||
|
logger.info(f"Found {len(ensemble_paths)} MPNN models")
|
||||||
|
|
||||||
|
model = LNPModel(
|
||||||
|
d_model=config["d_model"],
|
||||||
|
num_heads=config["num_heads"],
|
||||||
|
n_attn_layers=config["n_attn_layers"],
|
||||||
|
fusion_strategy=config["fusion_strategy"],
|
||||||
|
head_hidden_dim=config["head_hidden_dim"],
|
||||||
|
dropout=config["dropout"],
|
||||||
|
mpnn_ensemble_paths=ensemble_paths,
|
||||||
|
mpnn_device=mpnn_device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = LNPModelWithoutMPNN(
|
||||||
|
d_model=config["d_model"],
|
||||||
|
num_heads=config["num_heads"],
|
||||||
|
n_attn_layers=config["n_attn_layers"],
|
||||||
|
fusion_strategy=config["fusion_strategy"],
|
||||||
|
head_hidden_dim=config["head_hidden_dim"],
|
||||||
|
dropout=config["dropout"],
|
||||||
|
)
|
||||||
|
|
||||||
model = LNPModelWithoutMPNN(
|
|
||||||
d_model=config["d_model"],
|
|
||||||
num_heads=config["num_heads"],
|
|
||||||
n_attn_layers=config["n_attn_layers"],
|
|
||||||
fusion_strategy=config["fusion_strategy"],
|
|
||||||
head_hidden_dim=config["head_hidden_dim"],
|
|
||||||
dropout=config["dropout"],
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint["model_state_dict"])
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -43,7 +81,7 @@ def load_model(model_path: Path, device: torch.device) -> LNPModelWithoutMPNN:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_batch(
|
def predict_batch(
|
||||||
model: LNPModelWithoutMPNN,
|
model: Union[LNPModel, LNPModelWithoutMPNN],
|
||||||
loader: DataLoader,
|
loader: DataLoader,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
|
|||||||
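A quick sketch of calling the updated loader (the model path is the default that `train.py` saves to; `load_model` is the function above):

```python
import torch
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Picks LNPModel or LNPModelWithoutMPNN from the checkpoint's config["use_mpnn"];
# when MPNN is enabled, the ensemble is auto-discovered and kept on CPU by default.
model = load_model(Path("models/model.pt"), device)
```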
lnp_ml/modeling/pretrain.py | 361 (new file)

@@ -0,0 +1,361 @@
+"""Pretraining script: pretrain the backbone + delivery head on external LiON data."""
+
+import json
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pandas as pd
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from loguru import logger
+from tqdm import tqdm
+import typer
+
+from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
+from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
+from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
+
+# Default path for the MPNN ensemble
+DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
+
+
+def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
+    """
+    Auto-discover the model.pt files of the MPNN ensemble.
+
+    Searches base_dir for all cv_*/fold_*/model_*/model.pt files.
+    """
+    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
+    if not model_paths:
+        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
+    return [str(p) for p in model_paths]
+
+
+app = typer.Typer()
+
+
+class EarlyStopping:
+    """Early-stopping helper."""
+
+    def __init__(self, patience: int = 10, min_delta: float = 0.0):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.best_loss = float("inf")
+
+    def __call__(self, val_loss: float) -> bool:
+        if val_loss < self.best_loss - self.min_delta:
+            self.best_loss = val_loss
+            self.counter = 0
+            return False
+        self.counter += 1
+        return self.counter >= self.patience
def warmup_cache(model: nn.Module, smiles_list: List[str], batch_size: int = 256) -> None:
|
||||||
|
"""
|
||||||
|
预热 RDKit 特征缓存,避免训练时计算阻塞。
|
||||||
|
"""
|
||||||
|
unique_smiles = list(set(smiles_list))
|
||||||
|
logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
|
||||||
|
|
||||||
|
for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
|
||||||
|
batch = unique_smiles[i:i + batch_size]
|
||||||
|
model.rdkit_encoder(batch)
|
||||||
|
|
||||||
|
logger.success(f"Cache warmup complete. Cached {len(model.rdkit_encoder._cache)} SMILES.")
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch_delivery(
|
||||||
|
model: nn.Module,
|
||||||
|
loader: DataLoader,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
device: torch.device,
|
||||||
|
epoch: int = 0,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
单个 epoch 的预训练(仅 delivery 任务)。
|
||||||
|
"""
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
n_samples = 0
|
||||||
|
|
||||||
|
pbar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
|
||||||
|
for batch in pbar:
|
||||||
|
smiles = batch["smiles"]
|
||||||
|
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||||
|
targets = batch["targets"]["delivery"].to(device)
|
||||||
|
mask = batch["mask"]["delivery"].to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward: 只预测 delivery
|
||||||
|
pred = model.forward_delivery(smiles, tabular) # [B, 1]
|
||||||
|
pred = pred.squeeze(-1) # [B]
|
||||||
|
|
||||||
|
# 计算损失(仅对有效样本)
|
||||||
|
if mask.any():
|
||||||
|
loss = nn.functional.mse_loss(pred[mask], targets[mask])
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item() * mask.sum().item()
|
||||||
|
n_samples += mask.sum().item()
|
||||||
|
|
||||||
|
pbar.set_postfix({"loss": total_loss / max(n_samples, 1)})
|
||||||
|
|
||||||
|
avg_loss = total_loss / max(n_samples, 1)
|
||||||
|
return {"loss": avg_loss, "n_samples": n_samples}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate_delivery(
|
||||||
|
model: nn.Module,
|
||||||
|
loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
验证(仅 delivery 任务)。
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
n_samples = 0
|
||||||
|
|
||||||
|
for batch in loader:
|
||||||
|
smiles = batch["smiles"]
|
||||||
|
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||||
|
targets = batch["targets"]["delivery"].to(device)
|
||||||
|
mask = batch["mask"]["delivery"].to(device)
|
||||||
|
|
||||||
|
pred = model.forward_delivery(smiles, tabular).squeeze(-1)
|
||||||
|
|
||||||
|
if mask.any():
|
||||||
|
loss = nn.functional.mse_loss(pred[mask], targets[mask])
|
||||||
|
total_loss += loss.item() * mask.sum().item()
|
||||||
|
n_samples += mask.sum().item()
|
||||||
|
|
||||||
|
avg_loss = total_loss / max(n_samples, 1)
|
||||||
|
return {"loss": avg_loss, "n_samples": n_samples}
|
||||||
|
|
||||||
|
|
||||||
|
def pretrain(
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
model: nn.Module,
|
||||||
|
device: torch.device,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-5,
|
||||||
|
epochs: int = 50,
|
||||||
|
patience: int = 10,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
预训练循环。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
训练历史和最佳验证损失
|
||||||
|
"""
|
||||||
|
model = model.to(device)
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer, mode="min", factor=0.5, patience=5, verbose=True
|
||||||
|
)
|
||||||
|
early_stopping = EarlyStopping(patience=patience)
|
||||||
|
|
||||||
|
history = {"train": [], "val": []}
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
best_state = None
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
# Train
|
||||||
|
train_metrics = train_epoch_delivery(model, train_loader, optimizer, device, epoch)
|
||||||
|
|
||||||
|
# Validate
|
||||||
|
val_metrics = validate_delivery(model, val_loader, device)
|
||||||
|
|
||||||
|
# Log
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {epoch + 1}/{epochs} | "
|
||||||
|
f"Train Loss: {train_metrics['loss']:.4f} | "
|
||||||
|
f"Val Loss: {val_metrics['loss']:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
history["train"].append(train_metrics)
|
||||||
|
history["val"].append(val_metrics)
|
||||||
|
|
||||||
|
# Learning rate scheduling
|
||||||
|
scheduler.step(val_metrics["loss"])
|
||||||
|
|
||||||
|
# Save best model
|
||||||
|
if val_metrics["loss"] < best_val_loss:
|
||||||
|
best_val_loss = val_metrics["loss"]
|
||||||
|
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||||
|
logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if early_stopping(val_metrics["loss"]):
|
||||||
|
logger.info(f"Early stopping at epoch {epoch + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Restore best model
|
||||||
|
if best_state is not None:
|
||||||
|
model.load_state_dict(best_state)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"history": history,
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def main(
|
||||||
|
# 数据路径(已处理的 parquet 文件)
|
||||||
|
train_path: Path = PROCESSED_DATA_DIR / "train_pretrain.parquet",
|
||||||
|
val_path: Path = PROCESSED_DATA_DIR / "val_pretrain.parquet",
|
||||||
|
output_dir: Path = MODELS_DIR,
|
||||||
|
# 模型参数
|
||||||
|
d_model: int = 256,
|
||||||
|
num_heads: int = 8,
|
||||||
|
n_attn_layers: int = 4,
|
||||||
|
fusion_strategy: str = "attention",
|
||||||
|
head_hidden_dim: int = 128,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
# MPNN 参数(可选)
|
||||||
|
use_mpnn: bool = False, # 启用 MPNN,自动从默认路径加载 ensemble
|
||||||
|
mpnn_checkpoint: Optional[str] = None,
|
||||||
|
mpnn_ensemble_paths: Optional[str] = None, # 逗号分隔的路径列表
|
||||||
|
mpnn_device: str = "cpu",
|
||||||
|
# 训练参数
|
||||||
|
batch_size: int = 64,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-5,
|
||||||
|
epochs: int = 50,
|
||||||
|
patience: int = 10,
|
||||||
|
# 设备
|
||||||
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
在外部 LiON 数据上预训练 backbone + delivery head。
|
||||||
|
|
||||||
|
需要先运行 `make data_pretrain` 生成 parquet 文件。
|
||||||
|
|
||||||
|
使用 --use-mpnn 启用 MPNN encoder(自动从 models/mpnn/all_amine_split_for_LiON 加载)。
|
||||||
|
|
||||||
|
产出:
|
||||||
|
- models/pretrain_delivery.pt: 包含 backbone + delivery head 权重
|
||||||
|
- models/pretrain_history.json: 训练历史
|
||||||
|
"""
|
||||||
|
logger.info(f"Using device: {device}")
|
||||||
|
device_obj = torch.device(device)
|
||||||
|
|
||||||
|
# 加载已处理的 parquet 文件
|
||||||
|
logger.info(f"Loading train data from {train_path}")
|
||||||
|
train_df = pd.read_parquet(train_path)
|
||||||
|
train_dataset = ExternalDeliveryDataset(train_df)
|
||||||
|
|
||||||
|
logger.info(f"Loading val data from {val_path}")
|
||||||
|
val_df = pd.read_parquet(val_path)
|
||||||
|
val_dataset = ExternalDeliveryDataset(val_df)
|
||||||
|
|
||||||
|
logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析 MPNN 配置
|
||||||
|
# 优先级:mpnn_checkpoint > mpnn_ensemble_paths > use_mpnn(自动查找)
|
||||||
|
ensemble_paths_list = None
|
||||||
|
if mpnn_ensemble_paths:
|
||||||
|
ensemble_paths_list = mpnn_ensemble_paths.split(",")
|
||||||
|
elif use_mpnn and mpnn_checkpoint is None:
|
||||||
|
# --use-mpnn 但没有指定具体路径,自动查找
|
||||||
|
logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
|
||||||
|
ensemble_paths_list = find_mpnn_ensemble_paths()
|
||||||
|
logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
|
||||||
|
|
||||||
|
enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
logger.info(f"Creating model (use_mpnn={enable_mpnn})...")
|
||||||
|
if enable_mpnn:
|
||||||
|
model = LNPModel(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
mpnn_checkpoint=mpnn_checkpoint,
|
||||||
|
mpnn_ensemble_paths=ensemble_paths_list,
|
||||||
|
mpnn_device=mpnn_device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = LNPModelWithoutMPNN(
|
||||||
|
d_model=d_model,
|
||||||
|
num_heads=num_heads,
|
||||||
|
n_attn_layers=n_attn_layers,
|
||||||
|
fusion_strategy=fusion_strategy,
|
||||||
|
head_hidden_dim=head_hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
logger.info(f"Model parameters: {n_params:,}")
|
||||||
|
|
||||||
|
# 预热 RDKit 缓存(避免训练时阻塞)
|
||||||
|
all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()
|
||||||
|
warmup_cache(model, all_smiles, batch_size=256)
|
||||||
|
|
||||||
|
# 预训练
|
||||||
|
logger.info("Starting pretraining on external data (delivery only)...")
|
||||||
|
result = pretrain(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
model=model,
|
||||||
|
device=device_obj,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
epochs=epochs,
|
||||||
|
patience=patience,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存预训练 checkpoint
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
checkpoint_path = output_dir / "pretrain_delivery.pt"
|
||||||
|
torch.save(
|
||||||
|
{
|
||||||
|
"model_state_dict": model.state_dict(),
|
||||||
|
"backbone_state_dict": model.get_backbone_state_dict(),
|
||||||
|
"delivery_head_state_dict": model.get_delivery_head_state_dict(),
|
||||||
|
"config": {
|
||||||
|
"d_model": d_model,
|
||||||
|
"num_heads": num_heads,
|
||||||
|
"n_attn_layers": n_attn_layers,
|
||||||
|
"fusion_strategy": fusion_strategy,
|
||||||
|
"head_hidden_dim": head_hidden_dim,
|
||||||
|
"dropout": dropout,
|
||||||
|
"use_mpnn": enable_mpnn,
|
||||||
|
},
|
||||||
|
"best_val_loss": result["best_val_loss"],
|
||||||
|
},
|
||||||
|
checkpoint_path,
|
||||||
|
)
|
||||||
|
logger.success(f"Saved pretrain checkpoint to {checkpoint_path}")
|
||||||
|
|
||||||
|
# 保存训练历史
|
||||||
|
history_path = output_dir / "pretrain_history.json"
|
||||||
|
with open(history_path, "w") as f:
|
||||||
|
json.dump(result["history"], f, indent=2)
|
||||||
|
logger.success(f"Saved pretrain history to {history_path}")
|
||||||
|
|
||||||
|
logger.success(
|
||||||
|
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
|
|
||||||
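For reference, a small sketch of inspecting the checkpoint saved above (the keys follow the torch.save call; the path is the repo default):

```python
import torch

ckpt = torch.load("models/pretrain_delivery.pt", map_location="cpu")
print(sorted(ckpt["config"].keys()))      # d_model, dropout, ..., use_mpnn
print(f'{ckpt["best_val_loss"]:.4f}')     # 0.5465 for the run recorded below
```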
lnp_ml/modeling/train.py

@@ -2,7 +2,7 @@

 import json
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional, Union

 import pandas as pd
 import torch
@@ -12,7 +12,7 @@ import typer

 from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
 from lnp_ml.dataset import LNPDataset, collate_fn
-from lnp_ml.modeling.models import LNPModelWithoutMPNN
+from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
 from lnp_ml.modeling.trainer import (
     train_epoch,
     validate,
@@ -20,6 +20,21 @@ from lnp_ml.modeling.trainer import (
     LossWeights,
 )

+# Default path for the MPNN ensemble
+DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
+
+
+def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
+    """
+    Auto-discover the model.pt files of the MPNN ensemble.
+
+    Searches base_dir for all cv_*/fold_*/model_*/model.pt files.
+    """
+    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
+    if not model_paths:
+        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
+    return [str(p) for p in model_paths]
+

 app = typer.Typer()

@@ -31,16 +46,35 @@ def create_model(
     fusion_strategy: str = "attention",
     head_hidden_dim: int = 128,
     dropout: float = 0.1,
-) -> LNPModelWithoutMPNN:
-    """Create the model."""
-    return LNPModelWithoutMPNN(
-        d_model=d_model,
-        num_heads=num_heads,
-        n_attn_layers=n_attn_layers,
-        fusion_strategy=fusion_strategy,
-        head_hidden_dim=head_hidden_dim,
-        dropout=dropout,
-    )
+    # MPNN parameters (optional)
+    mpnn_checkpoint: Optional[str] = None,
+    mpnn_ensemble_paths: Optional[List[str]] = None,
+    mpnn_device: str = "cpu",
+) -> Union[LNPModel, LNPModelWithoutMPNN]:
+    """Create the model (with optional MPNN encoder support)."""
+    use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None
+
+    if use_mpnn:
+        return LNPModel(
+            d_model=d_model,
+            num_heads=num_heads,
+            n_attn_layers=n_attn_layers,
+            fusion_strategy=fusion_strategy,
+            head_hidden_dim=head_hidden_dim,
+            dropout=dropout,
+            mpnn_checkpoint=mpnn_checkpoint,
+            mpnn_ensemble_paths=mpnn_ensemble_paths,
+            mpnn_device=mpnn_device,
+        )
+    else:
+        return LNPModelWithoutMPNN(
+            d_model=d_model,
+            num_heads=num_heads,
+            n_attn_layers=n_attn_layers,
+            fusion_strategy=fusion_strategy,
+            head_hidden_dim=head_hidden_dim,
+            dropout=dropout,
+        )


 def train_model(
@@ -191,6 +225,11 @@ def main(
     fusion_strategy: str = "attention",
     head_hidden_dim: int = 128,
     dropout: float = 0.1,
+    # MPNN parameters (optional)
+    use_mpnn: bool = False,  # enable MPNN, auto-loading the ensemble from the default path
+    mpnn_checkpoint: Optional[str] = None,
+    mpnn_ensemble_paths: Optional[str] = None,  # comma-separated list of paths
+    mpnn_device: str = "cpu",
     # Training parameters
     batch_size: int = 32,
     lr: float = 1e-4,
@@ -201,13 +240,20 @@ def main(
     tune: bool = False,
     n_trials: int = 20,
     epochs_per_trial: int = 30,
+    # Pretrained-weight loading
+    init_from_pretrain: Optional[Path] = None,
+    load_delivery_head: bool = True,
+    freeze_backbone: bool = False,  # freeze the backbone, train only the heads
     # Device
     device: str = "cuda" if torch.cuda.is_available() else "cpu",
 ):
     """
-    Train the LNP prediction model.
+    Train the LNP prediction model (multi-task finetuning).

     Use --tune to enable hyperparameter tuning.
+    Use --init-from-pretrain to initialize the backbone from a pretraining checkpoint.
+    Use --use-mpnn to enable the MPNN encoder (auto-loaded from models/mpnn/all_amine_split_for_LiON).
+    Use --freeze-backbone to freeze the backbone and train only the multi-task heads.
     """
     logger.info(f"Using device: {device}")
     device = torch.device(device)
@@ -258,8 +304,21 @@ def main(
         lr = best_params["lr"]
         weight_decay = best_params["weight_decay"]

+    # Resolve the MPNN configuration
+    # Priority: mpnn_checkpoint > mpnn_ensemble_paths > use_mpnn (auto-discovery)
+    ensemble_paths_list = None
+    if mpnn_ensemble_paths:
+        ensemble_paths_list = mpnn_ensemble_paths.split(",")
+    elif use_mpnn and mpnn_checkpoint is None:
+        # --use-mpnn given without explicit paths: auto-discover
+        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
+        ensemble_paths_list = find_mpnn_ensemble_paths()
+        logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
+
+    enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
+
     # Create the model
-    logger.info("Creating model...")
+    logger.info(f"Creating model (use_mpnn={enable_mpnn})...")
     model = create_model(
         d_model=d_model,
         num_heads=num_heads,
@@ -267,11 +326,48 @@ def main(
         fusion_strategy=fusion_strategy,
         head_hidden_dim=head_hidden_dim,
         dropout=dropout,
+        mpnn_checkpoint=mpnn_checkpoint,
+        mpnn_ensemble_paths=ensemble_paths_list,
+        mpnn_device=mpnn_device,
     )

+    # Load pretrained weights (if requested)
+    if init_from_pretrain is not None:
+        logger.info(f"Loading pretrain weights from {init_from_pretrain}")
+        checkpoint = torch.load(init_from_pretrain, map_location="cpu")
+
+        # Check that the configurations are compatible
+        pretrain_config = checkpoint.get("config", {})
+        if pretrain_config.get("d_model") != d_model:
+            logger.warning(
+                f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
+                f"current={d_model}. Skipping pretrain loading."
+            )
+        else:
+            # Load backbone + (optionally) delivery head
+            model.load_pretrain_weights(
+                pretrain_state_dict=checkpoint["model_state_dict"],
+                load_delivery_head=load_delivery_head,
+                strict=False,
+            )
+            logger.success(
+                f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})"
+            )
+
+    # Freeze the backbone (if requested)
+    if freeze_backbone:
+        logger.info("Freezing backbone (token_projector, cross_attention, fusion)...")
+        frozen_count = 0
+        for name, param in model.named_parameters():
+            if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
+                param.requires_grad = False
+                frozen_count += 1
+        logger.info(f"Frozen {frozen_count} parameter tensors")
+
     # Print model info
-    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.info(f"Model parameters: {n_params:,}")
+    n_params_total = sum(p.numel() for p in model.parameters())
+    n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")

     # Train
     logger.info("Starting training...")
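A side note on the freezing above: setting `requires_grad=False` already stops updates, but when building an optimizer yourself it is common practice — though not required, and not something this commit changes — to hand the optimizer only the trainable parameters so it keeps no state for frozen tensors. A hedged sketch:

```python
import torch

# Optional follow-up (not part of this commit): construct the optimizer over
# trainable parameters only, after the freezing loop has run.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=1e-5)
```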
@@ -297,8 +393,10 @@ def main(
         "fusion_strategy": fusion_strategy,
         "head_hidden_dim": head_hidden_dim,
         "dropout": dropout,
+        "use_mpnn": enable_mpnn,
     },
     "best_val_loss": result["best_val_loss"],
+    "init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
 }, model_path)
 logger.success(f"Saved model to {model_path}")
models/history.json | 2446 (file diff suppressed because it is too large)
models/model.pt | BIN (binary file not shown)
models/pretrain_delivery.pt | BIN, new file (binary file not shown)
models/pretrain_history.json | 286 (new file)

@@ -0,0 +1,286 @@
+{
+  "train": [
+    {"loss": 0.8151603855680482, "n_samples": 8721},
+    {"loss": 0.6867990841792819, "n_samples": 8721},
+    {"loss": 0.645540308540888, "n_samples": 8721},
+    {"loss": 0.5923176541020599, "n_samples": 8721},
+    {"loss": 0.5720762926262872, "n_samples": 8721},
+    {"loss": 0.5477570670417328, "n_samples": 8721},
+    {"loss": 0.5280393017717573, "n_samples": 8721},
+    {"loss": 0.5122504676513313, "n_samples": 8721},
+    {"loss": 0.49667307051028314, "n_samples": 8721},
+    {"loss": 0.486139440352648, "n_samples": 8721},
+    {"loss": 0.4749755339466122, "n_samples": 8721},
+    {"loss": 0.4636757543530298, "n_samples": 8721},
+    {"loss": 0.4543497681877452, "n_samples": 8721},
+    {"loss": 0.4408158337956461, "n_samples": 8721},
+    {"loss": 0.4419790126221837, "n_samples": 8721},
+    {"loss": 0.42850686623585116, "n_samples": 8721},
+    {"loss": 0.41607048387867007, "n_samples": 8721},
+    {"loss": 0.427172136486513, "n_samples": 8721},
+    {"loss": 0.4125568530569382, "n_samples": 8721},
+    {"loss": 0.39480836287767923, "n_samples": 8721},
+    {"loss": 0.3885056775666858, "n_samples": 8721},
+    {"loss": 0.3894976457588827, "n_samples": 8721},
+    {"loss": 0.3890058272899995, "n_samples": 8721},
+    {"loss": 0.3741690826284791, "n_samples": 8721},
+    {"loss": 0.3534914434345719, "n_samples": 8721},
+    {"loss": 0.3349389765134386, "n_samples": 8721},
+    {"loss": 0.32965143874976194, "n_samples": 8721},
+    {"loss": 0.32094062546116675, "n_samples": 8721},
+    {"loss": 0.32526135008251184, "n_samples": 8721},
+    {"loss": 0.31289531808423826, "n_samples": 8721},
+    {"loss": 0.3088379208288558, "n_samples": 8721},
+    {"loss": 0.2994744991261045, "n_samples": 8721},
+    {"loss": 0.2981521815160671, "n_samples": 8721},
+    {"loss": 0.29143649446979303, "n_samples": 8721},
+    {"loss": 0.29075756723379653, "n_samples": 8721}
+  ],
+  "val": [
+    {"loss": 0.7625711447683281, "n_samples": 969},
+    {"loss": 0.7092331695236781, "n_samples": 969},
+    {"loss": 0.7014068689723995, "n_samples": 969},
+    {"loss": 0.6595172673863646, "n_samples": 969},
+    {"loss": 0.6312279044905191, "n_samples": 969},
+    {"loss": 0.6349272860831151, "n_samples": 969},
+    {"loss": 0.6587623598744133, "n_samples": 969},
+    {"loss": 0.6093261837651732, "n_samples": 969},
+    {"loss": 0.6125607111474924, "n_samples": 969},
+    {"loss": 0.6005943137518024, "n_samples": 969},
+    {"loss": 0.6876292386783289, "n_samples": 969},
+    {"loss": 0.5940848466228036, "n_samples": 969},
+    {"loss": 0.5820883587079644, "n_samples": 969},
+    {"loss": 0.6302792748938035, "n_samples": 969},
+    {"loss": 0.5849901610914275, "n_samples": 969},
+    {"loss": 0.5830434826428553, "n_samples": 969},
+    {"loss": 0.5643168952858116, "n_samples": 969},
+    {"loss": 0.5592790719340829, "n_samples": 969},
+    {"loss": 0.600335100686833, "n_samples": 969},
+    {"loss": 0.5646457721097674, "n_samples": 969},
+    {"loss": 0.6288956836004376, "n_samples": 969},
+    {"loss": 0.5771863222183704, "n_samples": 969},
+    {"loss": 0.5738056593250687, "n_samples": 969},
+    {"loss": 0.5636531712085593, "n_samples": 969},
+    {"loss": 0.5465074849879163, "n_samples": 969},
+    {"loss": 0.5701294839843508, "n_samples": 969},
+    {"loss": 0.5570075802420438, "n_samples": 969},
+    {"loss": 0.5711473401701241, "n_samples": 969},
+    {"loss": 0.5576858864741552, "n_samples": 969},
+    {"loss": 0.5624132422716871, "n_samples": 969},
+    {"loss": 0.5655298555506272, "n_samples": 969},
+    {"loss": 0.5568078993151677, "n_samples": 969},
+    {"loss": 0.567752199383958, "n_samples": 969},
+    {"loss": 0.5683093779442603, "n_samples": 969},
+    {"loss": 0.5741443767974497, "n_samples": 969}
+  ]
+}
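A small sketch for consuming this history file, e.g. to find the best epoch (epoch 25 in the run above, val loss ≈ 0.5465):

```python
import json

with open("models/pretrain_history.json") as f:
    history = json.load(f)

val_losses = [e["loss"] for e in history["val"]]
best = min(range(len(val_losses)), key=val_losses.__getitem__)
print(f"best epoch: {best + 1}, val loss: {val_losses[best]:.4f}")
# best epoch: 25, val loss: 0.5465
```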
scripts/process_external.py | 98 (new file)

@@ -0,0 +1,98 @@
+"""External-data preprocessing script: external -> processed"""
+
+from pathlib import Path
+
+import pandas as pd
+import typer
+from loguru import logger
+from sklearn.model_selection import train_test_split
+
+from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
+from lnp_ml.dataset import process_external_dataframe, LNPDatasetConfig, get_phys_cols, get_exp_cols, COMP_COLS, HELP_COLS
+
+
+app = typer.Typer()
+
+
+@app.command()
+def main(
+    input_path: Path = EXTERNAL_DATA_DIR / "all_data_LiON.csv",
+    output_dir: Path = PROCESSED_DATA_DIR,
+    train_ratio: float = 0.9,
+    seed: int = 42,
+):
+    """
+    Process the external LiON data and produce parquet files for pretraining.
+
+    Outputs:
+    - processed/train_pretrain.parquet
+    - processed/val_pretrain.parquet
+    - processed/feature_columns_pretrain.txt
+    """
+    logger.info(f"Loading external data from {input_path}")
+    df = pd.read_csv(input_path)
+    logger.info(f"Loaded {len(df)} samples")
+
+    # Drop rows where quantified_delivery is missing
+    if "quantified_delivery" in df.columns:
+        before_len = len(df)
+        df = df[df["quantified_delivery"].notna()].reset_index(drop=True)
+        logger.info(f"Filtered NaN delivery: {before_len} -> {len(df)} samples")
+
+    # Process the data (column alignment, one-hot generation)
+    logger.info("Processing dataframe (column alignment, one-hot encoding)...")
+    df = process_external_dataframe(df)
+
+    # Collect the required columns
+    config = LNPDatasetConfig()
+    feature_cols = (
+        ["smiles"]
+        + config.comp_cols
+        + config.phys_cols
+        + config.help_cols
+        + config.exp_cols
+        + ["quantified_delivery"]
+    )
+
+    # Keep only the required columns
+    available_cols = [c for c in feature_cols if c in df.columns]
+    missing_cols = [c for c in feature_cols if c not in df.columns]
+    if missing_cols:
+        logger.warning(f"Missing columns (will be filled with 0): {missing_cols}")
+        for col in missing_cols:
+            df[col] = 0.0
+
+    df = df[feature_cols]
+
+    # Split into train/val
+    logger.info(f"Splitting data: train_ratio={train_ratio}, seed={seed}")
+    train_df, val_df = train_test_split(
+        df, train_size=train_ratio, random_state=seed, shuffle=True
+    )
+    train_df = train_df.reset_index(drop=True)
+    val_df = val_df.reset_index(drop=True)
+
+    logger.info(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")
+
+    # Save
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    train_path = output_dir / "train_pretrain.parquet"
+    val_path = output_dir / "val_pretrain.parquet"
+
+    train_df.to_parquet(train_path, index=False)
+    val_df.to_parquet(val_path, index=False)
+
+    logger.success(f"Saved train data to {train_path}")
+    logger.success(f"Saved val data to {val_path}")
+
+    # Save the feature-column configuration
+    cols_path = output_dir / "feature_columns_pretrain.txt"
+    with open(cols_path, "w") as f:
+        f.write("\n".join(feature_cols))
+    logger.success(f"Saved feature columns to {cols_path}")
+
+
+if __name__ == "__main__":
+    app()