lnp_ml/lnp_ml/modeling/trainer_balanced.py

"""带类权重的训练器:处理分类任务的数据不均衡问题"""
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
@dataclass
class ClassWeights:
"""分类任务的类权重"""
pdi: Optional[torch.Tensor] = None # [4] for 4 PDI classes
ee: Optional[torch.Tensor] = None # [3] for 3 EE classes
toxic: Optional[torch.Tensor] = None # [2] for binary toxic
@dataclass
class LossWeightsBalanced:
"""各任务的损失权重(与 trainer.py 的 LossWeights 兼容)"""
size: float = 1.0
pdi: float = 1.0
ee: float = 1.0
delivery: float = 1.0
biodist: float = 1.0
toxic: float = 1.0


def compute_class_weights_from_loader(
    loader: DataLoader,
    n_pdi_classes: int = 4,
    n_ee_classes: int = 3,
    n_toxic_classes: int = 2,
    smoothing: float = 0.1,
) -> ClassWeights:
    """
    Count class frequencies over a DataLoader and derive class weights.

    Uses inverse frequency: weight_c = N / (n_classes * count_c), with a
    smoothing term added to the counts to avoid extreme weights.

    Args:
        loader: DataLoader (iterated once).
        n_pdi_classes: Number of PDI classes.
        n_ee_classes: Number of EE classes.
        n_toxic_classes: Number of toxic classes.
        smoothing: Smoothing term (prevents division by zero and extreme weights).

    Returns:
        ClassWeights holding one weight tensor per classification task.
    """
    pdi_counts = torch.zeros(n_pdi_classes)
    ee_counts = torch.zeros(n_ee_classes)
    toxic_counts = torch.zeros(n_toxic_classes)

    for batch in loader:
        targets = batch["targets"]
        mask = batch["mask"]
        # PDI
        if "pdi" in targets and "pdi" in mask:
            m = mask["pdi"]
            if m.any():
                labels = targets["pdi"][m]
                for c in range(n_pdi_classes):
                    pdi_counts[c] += (labels == c).sum().item()
        # EE
        if "ee" in targets and "ee" in mask:
            m = mask["ee"]
            if m.any():
                labels = targets["ee"][m]
                for c in range(n_ee_classes):
                    ee_counts[c] += (labels == c).sum().item()
        # Toxic
        if "toxic" in targets and "toxic" in mask:
            m = mask["toxic"]
            if m.any():
                labels = targets["toxic"][m]
                for c in range(n_toxic_classes):
                    toxic_counts[c] += (labels == c).sum().item()

    def counts_to_weights(counts: torch.Tensor, n_classes: int) -> Optional[torch.Tensor]:
        """Convert raw counts into class weights."""
        total = counts.sum().item()
        if total == 0:
            return None
        # Inverse frequency with smoothing
        counts = counts + smoothing
        weights = total / (n_classes * counts)
        # Normalize to mean = 1
        weights = weights / weights.mean()
        return weights

    return ClassWeights(
        pdi=counts_to_weights(pdi_counts, n_pdi_classes),
        ee=counts_to_weights(ee_counts, n_ee_classes),
        toxic=counts_to_weights(toxic_counts, n_toxic_classes),
    )
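
# Worked example of the weighting above (hypothetical counts, not from the real
# dataset): for a binary task counted as [90, 10] with smoothing=0.1,
#   raw:        100 / (2 * [90.1, 10.1]) = [0.555, 4.950]
#   normalized: [0.555, 4.950] / mean    = [0.202, 1.798]
# so each minority-class sample weighs roughly 9x more in the cross-entropy.
# A usage sketch, assuming `train_loader` is your training DataLoader:
#
#   class_weights = compute_class_weights_from_loader(train_loader)
#   print(class_weights.toxic)  # e.g. tensor([0.2016, 1.7984]) for a 90/10 split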


def compute_multitask_loss_balanced(
    outputs: Dict[str, torch.Tensor],
    targets: Dict[str, torch.Tensor],
    mask: Dict[str, torch.Tensor],
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the multi-task loss with class weights.

    Args:
        outputs: Model outputs.
        targets: Ground-truth labels.
        mask: Valid-sample masks.
        task_weights: Per-task loss weights.
        class_weights: Per-class weights for the classification tasks.

    Returns:
        (total_loss, loss_dict): total loss and per-task losses.
    """
    task_weights = task_weights or LossWeightsBalanced()
    class_weights = class_weights or ClassWeights()
    losses = {}
    device = next(iter(outputs.values())).device
    total_loss = torch.tensor(0.0, device=device)

    # size: MSE loss (regression, no class weights needed)
    if "size" in targets and mask["size"].any():
        m = mask["size"]
        pred = outputs["size"][m].squeeze(-1)
        tgt = targets["size"][m]
        losses["size"] = F.mse_loss(pred, tgt)
        total_loss = total_loss + task_weights.size * losses["size"]

    # delivery: MSE loss (regression, no class weights needed)
    if "delivery" in targets and mask["delivery"].any():
        m = mask["delivery"]
        pred = outputs["delivery"][m].squeeze(-1)
        tgt = targets["delivery"][m]
        losses["delivery"] = F.mse_loss(pred, tgt)
        total_loss = total_loss + task_weights.delivery * losses["delivery"]

    # pdi: cross-entropy with class weights
    if "pdi" in targets and mask["pdi"].any():
        m = mask["pdi"]
        pred = outputs["pdi"][m]
        tgt = targets["pdi"][m]
        weight = class_weights.pdi.to(device) if class_weights.pdi is not None else None
        losses["pdi"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.pdi * losses["pdi"]

    # ee: cross-entropy with class weights
    if "ee" in targets and mask["ee"].any():
        m = mask["ee"]
        pred = outputs["ee"][m]
        tgt = targets["ee"][m]
        weight = class_weights.ee.to(device) if class_weights.ee is not None else None
        losses["ee"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.ee * losses["ee"]

    # toxic: cross-entropy with class weights
    if "toxic" in targets and mask["toxic"].any():
        m = mask["toxic"]
        pred = outputs["toxic"][m]
        tgt = targets["toxic"][m]
        weight = class_weights.toxic.to(device) if class_weights.toxic is not None else None
        losses["toxic"] = F.cross_entropy(pred, tgt, weight=weight)
        total_loss = total_loss + task_weights.toxic * losses["toxic"]

    # biodist: KL divergence (distribution target, no class weights needed)
    if "biodist" in targets and mask["biodist"].any():
        m = mask["biodist"]
        pred = outputs["biodist"][m]  # assumed to be probabilities (e.g. softmax output)
        tgt = targets["biodist"][m]
        losses["biodist"] = F.kl_div(
            pred.log().clamp(min=-100),  # clamp guards against log(0) = -inf
            tgt,
            reduction="batchmean",
        )
        total_loss = total_loss + task_weights.biodist * losses["biodist"]

    return total_loss, losses
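
# Minimal call sketch for the loss (tensor shapes are assumptions: a batch of 8,
# a 1-dim "size" regression head and a 4-class "pdi" head; tasks absent from
# `targets` are simply skipped):
#
#   outputs = {"size": torch.randn(8, 1), "pdi": torch.randn(8, 4)}
#   targets = {"size": torch.randn(8), "pdi": torch.randint(0, 4, (8,))}
#   mask = {"size": torch.ones(8, dtype=torch.bool),
#           "pdi": torch.ones(8, dtype=torch.bool)}
#   total, per_task = compute_multitask_loss_balanced(outputs, targets, mask)
#   # per_task holds the raw (pre-task-weight) "size" and "pdi" losses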


def train_epoch_balanced(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Train for one epoch with class weights."""
    model.train()
    total_loss = 0.0
    task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
    n_batches = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = {k: v.to(device) for k, v in batch["targets"].items()}
        mask = {k: v.to(device) for k, v in batch["mask"].items()}

        optimizer.zero_grad()
        outputs = model(smiles, tabular)
        loss, losses = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        for k, v in losses.items():
            task_losses[k] += v.item()
        n_batches += 1

    return {
        "loss": total_loss / n_batches,
        # Only report tasks that actually contributed loss this epoch
        **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
    }


@torch.no_grad()
def validate_balanced(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Validate with class weights."""
    model.eval()
    total_loss = 0.0
    task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
    n_batches = 0
    # Running counts for classification accuracy
    correct = {k: 0 for k in ["pdi", "ee", "toxic"]}
    total = {k: 0 for k in ["pdi", "ee", "toxic"]}

    for batch in tqdm(loader, desc="Validating", leave=False):
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = {k: v.to(device) for k, v in batch["targets"].items()}
        mask = {k: v.to(device) for k, v in batch["mask"].items()}

        outputs = model(smiles, tabular)
        loss, losses = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )
        total_loss += loss.item()
        for k, v in losses.items():
            task_losses[k] += v.item()
        n_batches += 1

        # Classification accuracy
        for k in ["pdi", "ee", "toxic"]:
            if k in targets and mask[k].any():
                m = mask[k]
                pred = outputs[k][m].argmax(dim=-1)
                tgt = targets[k][m]
                correct[k] += (pred == tgt).sum().item()
                total[k] += m.sum().item()

    metrics = {
        "loss": total_loss / n_batches,
        **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
    }
    # Attach accuracies
    for k in ["pdi", "ee", "toxic"]:
        if total[k] > 0:
            metrics[f"acc_{k}"] = correct[k] / total[k]
    return metrics


BACKBONE_PREFIXES = ("token_projector.", "cross_attention.", "fusion.")


def build_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    backbone_lr_ratio: float = 1.0,
) -> torch.optim.AdamW:
    """
    Build an AdamW optimizer with optional layered learning rates.

    When backbone_lr_ratio < 1.0, backbone parameters use lr * backbone_lr_ratio
    while the remaining parameters (task heads, etc.) use lr.
    backbone_lr_ratio = 1.0 is equivalent to a single uniform learning rate.
    """
    if backbone_lr_ratio >= 1.0:
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith(BACKBONE_PREFIXES):
            backbone_params.append(param)
        else:
            head_params.append(param)

    return torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": lr * backbone_lr_ratio},
            {"params": head_params, "lr": lr},
        ],
        weight_decay=weight_decay,
    )
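
# Sketch of inspecting the layered LRs (`my_model` is hypothetical; with
# backbone_lr_ratio=0.1 the backbone group runs at lr / 10):
#
#   opt = build_optimizer(my_model, lr=1e-4, weight_decay=1e-5, backbone_lr_ratio=0.1)
#   for group in opt.param_groups:
#       print(group["lr"], len(group["params"]))
#   # -> 1e-05 for the backbone group, then 0.0001 for the head group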


class EarlyStoppingBalanced:
    """Early stopping (compatible with EarlyStopping in trainer.py)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float("inf")
        self.best_epoch = 0
        self.should_stop = False

    def __call__(self, val_loss: float, epoch: int = 0) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

    def get_best_epoch(self) -> int:
        """Return the best epoch (1-indexed)."""
        return self.best_epoch + 1
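
# Behavior sketch with a made-up loss trace: patience=2 stops after two
# consecutive epochs that fail to beat the best loss by more than min_delta.
#
#   stopper = EarlyStoppingBalanced(patience=2)
#   for epoch, vl in enumerate([1.0, 0.8, 0.85, 0.9]):
#       if stopper(vl, epoch):
#           break
#   # stops on the fourth value; stopper.get_best_epoch() == 2 (the 0.8 epoch)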


def train_with_early_stopping(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 100,
    patience: int = 15,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
    backbone_lr_ratio: float = 1.0,
) -> Dict:
    """
    Full training loop with early stopping.

    Returns:
        Dict with keys: history, best_val_loss, best_epoch, best_state, epochs_trained
    """
    model = model.to(device)
    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStoppingBalanced(patience=patience)

    history = {"train": [], "val": []}
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        # Train
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )
        # Validate
        val_metrics = validate_balanced(
            model, val_loader, device, task_weights, class_weights
        )
        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        # Learning rate scheduling
        scheduler.step(val_metrics["loss"])

        # Save best model
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        # Early stopping
        if early_stopping(val_metrics["loss"], epoch):
            break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "history": history,
        "best_val_loss": best_val_loss,
        "best_epoch": early_stopping.get_best_epoch(),
        "best_state": best_state,
        "epochs_trained": len(history["train"]),
    }
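
# End-to-end sketch (hypothetical model and loaders; the class weights are
# computed from the training split only, so no validation statistics leak in):
#
#   cw = compute_class_weights_from_loader(train_loader)
#   result = train_with_early_stopping(
#       model, train_loader, val_loader, device,
#       lr=1e-4, epochs=100, patience=15, class_weights=cw,
#   )
#   print(result["best_epoch"], result["best_val_loss"])
#   # model already holds the best weights; result["best_state"] is the CPU copy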


def train_fixed_epochs(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
    use_cosine_annealing: bool = True,
    use_swa: bool = False,
    swa_start_epoch: Optional[int] = None,
    backbone_lr_ratio: float = 1.0,
) -> Dict:
    """
    Train for a fixed number of epochs (no early stopping).

    Used for outer-CV training and the final training run.

    Args:
        model: The model.
        train_loader: Training data.
        val_loader: Validation data (optional, for monitoring only).
        device: Device.
        lr: Learning rate.
        weight_decay: Weight decay.
        epochs: Number of epochs.
        task_weights: Per-task weights.
        class_weights: Per-class weights.
        use_cosine_annealing: Whether to use CosineAnnealingLR.
        use_swa: Whether to use SWA (stochastic weight averaging).
        swa_start_epoch: Epoch at which SWA starts (defaults to epochs * 0.75).
        backbone_lr_ratio: Backbone LR relative to the heads (1.0 = uniform LR).

    Returns:
        Dict with keys: history, final_state, epochs_trained
    """
    model = model.to(device)
    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
    if use_cosine_annealing:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    else:
        scheduler = None

    # SWA setup
    swa_model = None
    swa_scheduler = None
    if use_swa:
        from torch.optim.swa_utils import AveragedModel, SWALR

        swa_model = AveragedModel(model)
        # `or` would treat swa_start_epoch=0 as unset, so test against None
        swa_start = swa_start_epoch if swa_start_epoch is not None else int(epochs * 0.75)
        swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1)

    history = {"train": [], "val": []}
    for epoch in range(epochs):
        # Train
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )
        history["train"].append(train_metrics)

        # Validate (optional)
        if val_loader is not None:
            val_metrics = validate_balanced(
                model, val_loader, device, task_weights, class_weights
            )
            history["val"].append(val_metrics)

        # Scheduler step
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif scheduler is not None:
            scheduler.step()

    # Finalize SWA
    if use_swa and swa_model is not None:
        # Recompute batch-norm statistics for the averaged weights. Caveat:
        # update_bn is a no-op when the model has no BatchNorm layers; otherwise
        # it feeds each batch to the model as-is, which assumes the model
        # accepts the loader's batch format directly.
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()}
    else:
        final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    return {
        "history": history,
        "final_state": final_state,
        "epochs_trained": epochs,
    }
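
# Sketch of the SWA path (hypothetical names): with epochs=50 and the default
# swa_start_epoch, averaging starts at epoch 37 (= int(50 * 0.75)) and SWALR
# anneals the learning rate toward lr * 0.1 from then on.
#
#   result = train_fixed_epochs(
#       model, train_loader, val_loader=None, device=device,
#       lr=1e-4, epochs=50, use_swa=True, class_weights=cw,
#   )
#   model.load_state_dict(result["final_state"])  # averaged weights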