"""带类权重的训练器:处理分类任务的数据不均衡问题"""
|
||
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import DataLoader
|
||
from tqdm import tqdm
|
||
|
||
|
||
@dataclass
|
||
class ClassWeights:
|
||
"""分类任务的类权重"""
|
||
pdi: Optional[torch.Tensor] = None # [4] for 4 PDI classes
|
||
ee: Optional[torch.Tensor] = None # [3] for 3 EE classes
|
||
toxic: Optional[torch.Tensor] = None # [2] for binary toxic
|
||
|
||
|
||
@dataclass
|
||
class LossWeightsBalanced:
|
||
"""各任务的损失权重(与 trainer.py 的 LossWeights 兼容)"""
|
||
size: float = 1.0
|
||
pdi: float = 1.0
|
||
ee: float = 1.0
|
||
delivery: float = 1.0
|
||
biodist: float = 1.0
|
||
toxic: float = 1.0
|
||
|
||
|
||
def compute_class_weights_from_loader(
|
||
loader: DataLoader,
|
||
n_pdi_classes: int = 4,
|
||
n_ee_classes: int = 3,
|
||
n_toxic_classes: int = 2,
|
||
smoothing: float = 0.1,
|
||
) -> ClassWeights:
|
||
"""
|
||
从 DataLoader 统计类别频次并计算类权重。
|
||
|
||
使用 inverse frequency 方式:weight_c = N / (n_classes * count_c)
|
||
加 smoothing 避免极端权重。
|
||
|
||
Args:
|
||
loader: DataLoader(需要遍历一次)
|
||
n_pdi_classes: PDI 类别数
|
||
n_ee_classes: EE 类别数
|
||
n_toxic_classes: toxic 类别数
|
||
smoothing: 平滑系数(防止除零和极端权重)
|
||
|
||
Returns:
|
||
ClassWeights 对象,包含各分类任务的类权重张量
|
||
"""
|
||
pdi_counts = torch.zeros(n_pdi_classes)
|
||
ee_counts = torch.zeros(n_ee_classes)
|
||
toxic_counts = torch.zeros(n_toxic_classes)
|
||
|
||
for batch in loader:
|
||
targets = batch["targets"]
|
||
mask = batch["mask"]
|
||
|
||
# PDI
|
||
if "pdi" in targets and "pdi" in mask:
|
||
m = mask["pdi"]
|
||
if m.any():
|
||
labels = targets["pdi"][m]
|
||
for c in range(n_pdi_classes):
|
||
pdi_counts[c] += (labels == c).sum().item()
|
||
|
||
# EE
|
||
if "ee" in targets and "ee" in mask:
|
||
m = mask["ee"]
|
||
if m.any():
|
||
labels = targets["ee"][m]
|
||
for c in range(n_ee_classes):
|
||
ee_counts[c] += (labels == c).sum().item()
|
||
|
||
# Toxic
|
||
if "toxic" in targets and "toxic" in mask:
|
||
m = mask["toxic"]
|
||
if m.any():
|
||
labels = targets["toxic"][m]
|
||
for c in range(n_toxic_classes):
|
||
toxic_counts[c] += (labels == c).sum().item()
|
||
|
||
def counts_to_weights(counts: torch.Tensor, n_classes: int) -> Optional[torch.Tensor]:
|
||
"""将计数转换为类权重"""
|
||
total = counts.sum().item()
|
||
if total == 0:
|
||
return None
|
||
# Inverse frequency with smoothing
|
||
counts = counts + smoothing
|
||
weights = total / (n_classes * counts)
|
||
# Normalize to mean=1
|
||
weights = weights / weights.mean()
|
||
return weights
|
||
|
||
return ClassWeights(
|
||
pdi=counts_to_weights(pdi_counts, n_pdi_classes),
|
||
ee=counts_to_weights(ee_counts, n_ee_classes),
|
||
toxic=counts_to_weights(toxic_counts, n_toxic_classes),
|
||
)
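
# Worked example (illustrative numbers only, not data from this project):
# raw PDI counts [900, 50, 30, 20] with smoothing=0.1 give inverse-frequency
# weights total / (4 * counts) ≈ [0.28, 5.0, 8.3, 12.4]; dividing by their
# mean (≈ 6.5) rescales them to ≈ [0.04, 0.77, 1.28, 1.91], so the rarest
# class weighs roughly 45x more per sample than the most common one.
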
def compute_multitask_loss_balanced(
    outputs: Dict[str, torch.Tensor],
    targets: Dict[str, torch.Tensor],
    mask: Dict[str, torch.Tensor],
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the multi-task loss with class weights.

    Args:
        outputs: model outputs.
        targets: ground-truth labels.
        mask: valid-sample masks.
        task_weights: per-task loss weights.
        class_weights: per-class weights for the classification tasks.

    Returns:
        (total_loss, loss_dict): the total loss and the per-task losses.
    """
    task_weights = task_weights or LossWeightsBalanced()
    class_weights = class_weights or ClassWeights()

    losses = {}
    device = next(iter(outputs.values())).device
    total_loss = torch.tensor(0.0, device=device)

    # size: MSE loss (regression task, no class weights needed)
    if "size" in targets and mask["size"].any():
        m = mask["size"]
        pred = outputs["size"][m].squeeze(-1)
        tgt = targets["size"][m]
        losses["size"] = F.mse_loss(pred, tgt)
        total_loss = total_loss + task_weights.size * losses["size"]

    # delivery: MSE loss (regression task, no class weights needed)
    if "delivery" in targets and mask["delivery"].any():
        m = mask["delivery"]
        pred = outputs["delivery"][m].squeeze(-1)
        tgt = targets["delivery"][m]
        losses["delivery"] = F.mse_loss(pred, tgt)
        total_loss = total_loss + task_weights.delivery * losses["delivery"]

    # pdi: cross-entropy with class weights
    if "pdi" in targets and mask["pdi"].any():
        m = mask["pdi"]
        pred = outputs["pdi"][m]
        tgt = targets["pdi"][m]
        weight = class_weights.pdi.to(device) if class_weights.pdi is not None else None
        losses["pdi"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.pdi * losses["pdi"]

    # ee: cross-entropy with class weights
    if "ee" in targets and mask["ee"].any():
        m = mask["ee"]
        pred = outputs["ee"][m]
        tgt = targets["ee"][m]
        weight = class_weights.ee.to(device) if class_weights.ee is not None else None
        losses["ee"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.ee * losses["ee"]

    # toxic: cross-entropy with class weights
    if "toxic" in targets and mask["toxic"].any():
        m = mask["toxic"]
        pred = outputs["toxic"][m]
        tgt = targets["toxic"][m]
        weight = class_weights.toxic.to(device) if class_weights.toxic is not None else None
        losses["toxic"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.toxic * losses["toxic"]

    # biodist: KL divergence (distribution target, no class weights needed).
    # outputs["biodist"] is expected to already be a probability distribution;
    # the clamp keeps -inf out of the loss when a predicted probability is 0.
    if "biodist" in targets and mask["biodist"].any():
        m = mask["biodist"]
        pred = outputs["biodist"][m]
        tgt = targets["biodist"][m]
        losses["biodist"] = F.kl_div(
            pred.log().clamp(min=-100),
            tgt,
            reduction="batchmean",
        )
        total_loss = total_loss + task_weights.biodist * losses["biodist"]

    return total_loss, losses
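
# Minimal call sketch (hypothetical shapes; B is the batch size):
#   outputs = {"pdi": torch.randn(B, 4), "size": torch.randn(B, 1)}
#   targets = {"pdi": torch.randint(0, 4, (B,)), "size": torch.randn(B)}
#   mask = {"pdi": torch.ones(B, dtype=torch.bool), "size": torch.ones(B, dtype=torch.bool)}
#   total, per_task = compute_multitask_loss_balanced(outputs, targets, mask)
# Tasks that are absent from `targets`, or whose mask is entirely False,
# contribute nothing and do not appear in `per_task`.
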
def train_epoch_balanced(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Train for one epoch with class weights."""
    model.train()
    total_loss = 0.0
    task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
    n_batches = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = {k: v.to(device) for k, v in batch["targets"].items()}
        mask = {k: v.to(device) for k, v in batch["mask"].items()}

        optimizer.zero_grad()
        outputs = model(smiles, tabular)
        loss, losses = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        for k, v in losses.items():
            task_losses[k] += v.item()
        n_batches += 1

    return {
        "loss": total_loss / n_batches,
        **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
    }
@torch.no_grad()
def validate_balanced(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Validate with class weights."""
    model.eval()
    total_loss = 0.0
    task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
    n_batches = 0

    # Running tallies for classification accuracy
    correct = {k: 0 for k in ["pdi", "ee", "toxic"]}
    total = {k: 0 for k in ["pdi", "ee", "toxic"]}

    for batch in tqdm(loader, desc="Validating", leave=False):
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = {k: v.to(device) for k, v in batch["targets"].items()}
        mask = {k: v.to(device) for k, v in batch["mask"].items()}

        outputs = model(smiles, tabular)
        loss, losses = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )

        total_loss += loss.item()
        for k, v in losses.items():
            task_losses[k] += v.item()
        n_batches += 1

        # Classification accuracy
        for k in ["pdi", "ee", "toxic"]:
            if k in targets and mask[k].any():
                m = mask[k]
                pred = outputs[k][m].argmax(dim=-1)
                tgt = targets[k][m]
                correct[k] += (pred == tgt).sum().item()
                total[k] += m.sum().item()

    metrics = {
        "loss": total_loss / n_batches,
        **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
    }

    # Attach accuracies
    for k in ["pdi", "ee", "toxic"]:
        if total[k] > 0:
            metrics[f"acc_{k}"] = correct[k] / total[k]

    return metrics
BACKBONE_PREFIXES = ("token_projector.", "cross_attention.", "fusion.")


def build_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    backbone_lr_ratio: float = 1.0,
) -> torch.optim.AdamW:
    """
    Build an AdamW optimizer with optional layer-wise learning rates.

    When backbone_lr_ratio < 1.0, backbone parameters train at
    lr * backbone_lr_ratio while the remaining parameters (task heads, etc.)
    train at lr. Ratios >= 1.0 fall back to a single uniform learning rate.
    """
    if backbone_lr_ratio >= 1.0:
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith(BACKBONE_PREFIXES):
            backbone_params.append(param)
        else:
            head_params.append(param)

    return torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": lr * backbone_lr_ratio},
            {"params": head_params, "lr": lr},
        ],
        weight_decay=weight_decay,
    )
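
# Example (hypothetical values): build_optimizer(model, lr=1e-4,
# weight_decay=1e-5, backbone_lr_ratio=0.1) trains parameters whose names
# start with token_projector./cross_attention./fusion. at 1e-5 and the task
# heads at 1e-4; both param groups share the same weight_decay.
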
class EarlyStoppingBalanced:
    """Early stopping (compatible with EarlyStopping in trainer.py)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float("inf")
        self.best_epoch = 0
        self.should_stop = False

    def __call__(self, val_loss: float, epoch: int = 0) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

    def get_best_epoch(self) -> int:
        """Return the best epoch (1-indexed)."""
        return self.best_epoch + 1
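
# Usage sketch:
#   stopper = EarlyStoppingBalanced(patience=15)
#   for epoch in range(epochs):
#       val_loss = ...  # validate
#       if stopper(val_loss, epoch):
#           break
#   best = stopper.get_best_epoch()  # 1-indexed
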
def train_with_early_stopping(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 100,
    patience: int = 15,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
    backbone_lr_ratio: float = 1.0,
) -> Dict:
    """
    Full training loop with early stopping.

    Returns:
        Dict with keys: history, best_val_loss, best_epoch, best_state,
        epochs_trained.
    """
    model = model.to(device)
    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStoppingBalanced(patience=patience)

    history = {"train": [], "val": []}
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        # Train
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )

        # Validate
        val_metrics = validate_balanced(
            model, val_loader, device, task_weights, class_weights
        )

        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        # Learning rate scheduling
        scheduler.step(val_metrics["loss"])

        # Save best model
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        # Early stopping
        if early_stopping(val_metrics["loss"], epoch):
            break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "history": history,
        "best_val_loss": best_val_loss,
        "best_epoch": early_stopping.get_best_epoch(),
        "best_state": best_state,
        "epochs_trained": len(history["train"]),
    }
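
# Usage sketch (hypothetical model/loaders):
#   out = train_with_early_stopping(model, train_loader, val_loader, device,
#                                   epochs=100, patience=15)
# The model is already restored to its best weights on return;
# out["best_state"] holds a CPU copy for checkpointing.
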
def train_fixed_epochs(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
    use_cosine_annealing: bool = True,
    use_swa: bool = False,
    swa_start_epoch: Optional[int] = None,
    backbone_lr_ratio: float = 1.0,
) -> Dict:
    """
    Train for a fixed number of epochs (no early stopping).

    Used for outer-CV training and the final fit.

    Args:
        model: the model.
        train_loader: training data.
        val_loader: validation data (optional, used only for monitoring).
        device: device to train on.
        lr: learning rate.
        weight_decay: weight decay.
        epochs: number of epochs.
        task_weights: per-task loss weights.
        class_weights: per-class weights.
        use_cosine_annealing: whether to use CosineAnnealingLR.
        use_swa: whether to use stochastic weight averaging (SWA).
        swa_start_epoch: epoch at which SWA starts (default: epochs * 0.75).
        backbone_lr_ratio: backbone LR relative to the heads (1.0 = uniform).

    Returns:
        Dict with keys: history, final_state, epochs_trained.
    """
    model = model.to(device)
    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)

    if use_cosine_annealing:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    else:
        scheduler = None

    # SWA setup
    swa_model = None
    swa_scheduler = None
    if use_swa:
        from torch.optim.swa_utils import AveragedModel, SWALR

        swa_model = AveragedModel(model)
        swa_start = swa_start_epoch or int(epochs * 0.75)
        swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1)

    history = {"train": [], "val": []}

    for epoch in range(epochs):
        # Train
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )
        history["train"].append(train_metrics)

        # Validate (optional)
        if val_loader is not None:
            val_metrics = validate_balanced(
                model, val_loader, device, task_weights, class_weights
            )
            history["val"].append(val_metrics)

        # Scheduler step
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif scheduler is not None:
            scheduler.step()

    # Finalize SWA
    final_state = None
    if use_swa and swa_model is not None:
        # Refresh batch-norm running statistics for the averaged weights.
        # torch.optim.swa_utils.update_bn expects plain tensor batches, so the
        # dict batches from this loader are fed through the model manually
        # (same recipe as update_bn: reset stats, use a cumulative average).
        momenta = {}
        for module in swa_model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                module.reset_running_stats()
                momenta[module] = module.momentum
                module.momentum = None  # cumulative moving average
        if momenta:
            was_training = swa_model.training
            swa_model.train()
            with torch.no_grad():
                for batch in train_loader:
                    smiles = batch["smiles"]
                    tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                    swa_model(smiles, tabular)
            for module, momentum in momenta.items():
                module.momentum = momentum
            swa_model.train(was_training)
        final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()}
    else:
        final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    return {
        "history": history,
        "final_state": final_state,
        "epochs_trained": epochs,
    }
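

if __name__ == "__main__":
    # Smoke test on synthetic data -- a minimal sketch only. The real dataset
    # and model live elsewhere in this repo; _ToyDataset, _ToyModel, and the
    # "x" tabular feature below are made up to show the expected batch format
    # (dicts with "smiles", "tabular", "targets", "mask") end to end.
    from torch.utils.data import Dataset

    class _ToyDataset(Dataset):
        """Random samples in the dict format the trainer expects."""

        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return {
                "smiles": "CCO",
                "tabular": {"x": torch.randn(8)},
                "targets": {"size": torch.randn(()), "pdi": torch.randint(0, 4, ())},
                "mask": {"size": torch.tensor(True), "pdi": torch.tensor(True)},
            }

    def _collate(items):
        """Stack per-sample dicts into the batch dict used throughout this file."""
        return {
            "smiles": [it["smiles"] for it in items],
            "tabular": {"x": torch.stack([it["tabular"]["x"] for it in items])},
            "targets": {k: torch.stack([it["targets"][k] for it in items]) for k in ("size", "pdi")},
            "mask": {k: torch.stack([it["mask"][k] for it in items]) for k in ("size", "pdi")},
        }

    class _ToyModel(nn.Module):
        """Tiny stand-in with one regression head and one 4-class head."""

        def __init__(self):
            super().__init__()
            self.body = nn.Linear(8, 16)
            self.size_head = nn.Linear(16, 1)
            self.pdi_head = nn.Linear(16, 4)

        def forward(self, smiles, tabular):
            h = torch.relu(self.body(tabular["x"]))
            return {"size": self.size_head(h), "pdi": self.pdi_head(h)}

    loader = DataLoader(_ToyDataset(), batch_size=16, collate_fn=_collate)
    cw = compute_class_weights_from_loader(loader)  # ee/toxic stay None here
    result = train_fixed_epochs(
        _ToyModel(), loader, None, torch.device("cpu"),
        epochs=2, class_weights=cw, use_cosine_annealing=False,
    )
    print("train loss per epoch:", [round(m["loss"], 4) for m in result["history"]["train"]])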