"""带类权重的训练器:处理分类任务的数据不均衡问题""" from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, field import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm import tqdm @dataclass class ClassWeights: """分类任务的类权重""" pdi: Optional[torch.Tensor] = None # [4] for 4 PDI classes ee: Optional[torch.Tensor] = None # [3] for 3 EE classes toxic: Optional[torch.Tensor] = None # [2] for binary toxic @dataclass class LossWeightsBalanced: """各任务的损失权重(与 trainer.py 的 LossWeights 兼容)""" size: float = 1.0 pdi: float = 1.0 ee: float = 1.0 delivery: float = 1.0 biodist: float = 1.0 toxic: float = 1.0 def compute_class_weights_from_loader( loader: DataLoader, n_pdi_classes: int = 4, n_ee_classes: int = 3, n_toxic_classes: int = 2, smoothing: float = 0.1, ) -> ClassWeights: """ 从 DataLoader 统计类别频次并计算类权重。 使用 inverse frequency 方式:weight_c = N / (n_classes * count_c) 加 smoothing 避免极端权重。 Args: loader: DataLoader(需要遍历一次) n_pdi_classes: PDI 类别数 n_ee_classes: EE 类别数 n_toxic_classes: toxic 类别数 smoothing: 平滑系数(防止除零和极端权重) Returns: ClassWeights 对象,包含各分类任务的类权重张量 """ pdi_counts = torch.zeros(n_pdi_classes) ee_counts = torch.zeros(n_ee_classes) toxic_counts = torch.zeros(n_toxic_classes) for batch in loader: targets = batch["targets"] mask = batch["mask"] # PDI if "pdi" in targets and "pdi" in mask: m = mask["pdi"] if m.any(): labels = targets["pdi"][m] for c in range(n_pdi_classes): pdi_counts[c] += (labels == c).sum().item() # EE if "ee" in targets and "ee" in mask: m = mask["ee"] if m.any(): labels = targets["ee"][m] for c in range(n_ee_classes): ee_counts[c] += (labels == c).sum().item() # Toxic if "toxic" in targets and "toxic" in mask: m = mask["toxic"] if m.any(): labels = targets["toxic"][m] for c in range(n_toxic_classes): toxic_counts[c] += (labels == c).sum().item() def counts_to_weights(counts: torch.Tensor, n_classes: int) -> Optional[torch.Tensor]: """将计数转换为类权重""" total = counts.sum().item() if total == 0: return None # Inverse frequency with smoothing counts = counts + smoothing weights = total / (n_classes * counts) # Normalize to mean=1 weights = weights / weights.mean() return weights return ClassWeights( pdi=counts_to_weights(pdi_counts, n_pdi_classes), ee=counts_to_weights(ee_counts, n_ee_classes), toxic=counts_to_weights(toxic_counts, n_toxic_classes), ) def compute_multitask_loss_balanced( outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor], task_weights: Optional[LossWeightsBalanced] = None, class_weights: Optional[ClassWeights] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ 计算带类权重的多任务损失。 Args: outputs: 模型输出 targets: 真实标签 mask: 有效样本掩码 task_weights: 各任务权重 class_weights: 分类任务的类权重 Returns: (total_loss, loss_dict) 总损失和各任务损失 """ task_weights = task_weights or LossWeightsBalanced() class_weights = class_weights or ClassWeights() losses = {} device = next(iter(outputs.values())).device total_loss = torch.tensor(0.0, device=device) # size: MSE loss(回归任务,不需要类权重) if "size" in targets and mask["size"].any(): m = mask["size"] pred = outputs["size"][m].squeeze(-1) tgt = targets["size"][m] losses["size"] = F.mse_loss(pred, tgt) total_loss = total_loss + task_weights.size * losses["size"] # delivery: MSE loss(回归任务,不需要类权重) if "delivery" in targets and mask["delivery"].any(): m = mask["delivery"] pred = outputs["delivery"][m].squeeze(-1) tgt = targets["delivery"][m] losses["delivery"] = F.mse_loss(pred, tgt) total_loss = total_loss + task_weights.delivery * losses["delivery"] # pdi: CrossEntropy with class weights if "pdi" in targets and mask["pdi"].any(): m = mask["pdi"] pred = outputs["pdi"][m] tgt = targets["pdi"][m] weight = class_weights.pdi.to(device) if class_weights.pdi is not None else None losses["pdi"] = F.cross_entropy(pred, tgt, weight=weight) total_loss = total_loss + task_weights.pdi * losses["pdi"] # ee: CrossEntropy with class weights if "ee" in targets and mask["ee"].any(): m = mask["ee"] pred = outputs["ee"][m] tgt = targets["ee"][m] weight = class_weights.ee.to(device) if class_weights.ee is not None else None losses["ee"] = F.cross_entropy(pred, tgt, weight=weight) total_loss = total_loss + task_weights.ee * losses["ee"] # toxic: CrossEntropy with class weights if "toxic" in targets and mask["toxic"].any(): m = mask["toxic"] pred = outputs["toxic"][m] tgt = targets["toxic"][m] weight = class_weights.toxic.to(device) if class_weights.toxic is not None else None losses["toxic"] = F.cross_entropy(pred, tgt, weight=weight) total_loss = total_loss + task_weights.toxic * losses["toxic"] # biodist: KL divergence(分布任务,不需要类权重) if "biodist" in targets and mask["biodist"].any(): m = mask["biodist"] pred = outputs["biodist"][m] tgt = targets["biodist"][m] losses["biodist"] = F.kl_div( pred.log().clamp(min=-100), tgt, reduction="batchmean", ) total_loss = total_loss + task_weights.biodist * losses["biodist"] return total_loss, losses def train_epoch_balanced( model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, task_weights: Optional[LossWeightsBalanced] = None, class_weights: Optional[ClassWeights] = None, ) -> Dict[str, float]: """带类权重的训练一个 epoch""" model.train() total_loss = 0.0 task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]} n_batches = 0 for batch in tqdm(loader, desc="Training", leave=False): smiles = batch["smiles"] tabular = {k: v.to(device) for k, v in batch["tabular"].items()} targets = {k: v.to(device) for k, v in batch["targets"].items()} mask = {k: v.to(device) for k, v in batch["mask"].items()} optimizer.zero_grad() outputs = model(smiles, tabular) loss, losses = compute_multitask_loss_balanced( outputs, targets, mask, task_weights, class_weights ) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() total_loss += loss.item() for k, v in losses.items(): task_losses[k] += v.item() n_batches += 1 return { "loss": total_loss / n_batches, **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0}, } @torch.no_grad() def validate_balanced( model: nn.Module, loader: DataLoader, device: torch.device, task_weights: Optional[LossWeightsBalanced] = None, class_weights: Optional[ClassWeights] = None, ) -> Dict[str, float]: """带类权重的验证""" model.eval() total_loss = 0.0 task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]} n_batches = 0 # 用于计算准确率 correct = {k: 0 for k in ["pdi", "ee", "toxic"]} total = {k: 0 for k in ["pdi", "ee", "toxic"]} for batch in tqdm(loader, desc="Validating", leave=False): smiles = batch["smiles"] tabular = {k: v.to(device) for k, v in batch["tabular"].items()} targets = {k: v.to(device) for k, v in batch["targets"].items()} mask = {k: v.to(device) for k, v in batch["mask"].items()} outputs = model(smiles, tabular) loss, losses = compute_multitask_loss_balanced( outputs, targets, mask, task_weights, class_weights ) total_loss += loss.item() for k, v in losses.items(): task_losses[k] += v.item() n_batches += 1 # 计算分类准确率 for k in ["pdi", "ee", "toxic"]: if k in targets and mask[k].any(): m = mask[k] pred = outputs[k][m].argmax(dim=-1) tgt = targets[k][m] correct[k] += (pred == tgt).sum().item() total[k] += m.sum().item() metrics = { "loss": total_loss / n_batches, **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0}, } # 添加准确率 for k in ["pdi", "ee", "toxic"]: if total[k] > 0: metrics[f"acc_{k}"] = correct[k] / total[k] return metrics BACKBONE_PREFIXES = ("token_projector.", "cross_attention.", "fusion.") def build_optimizer( model: nn.Module, lr: float, weight_decay: float, backbone_lr_ratio: float = 1.0, ) -> torch.optim.AdamW: """ 构建 AdamW 优化器,支持分层学习率。 当 backbone_lr_ratio < 1.0 时,backbone 参数使用 lr * backbone_lr_ratio, 其余参数(task heads 等)使用 lr。backbone_lr_ratio = 1.0 等价于统一学习率。 """ if backbone_lr_ratio >= 1.0: return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) backbone_params = [] head_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue if name.startswith(BACKBONE_PREFIXES): backbone_params.append(param) else: head_params.append(param) return torch.optim.AdamW( [ {"params": backbone_params, "lr": lr * backbone_lr_ratio}, {"params": head_params, "lr": lr}, ], weight_decay=weight_decay, ) class EarlyStoppingBalanced: """早停机制(与 trainer.py 的 EarlyStopping 兼容)""" def __init__(self, patience: int = 10, min_delta: float = 0.0): self.patience = patience self.min_delta = min_delta self.counter = 0 self.best_loss = float("inf") self.best_epoch = 0 self.should_stop = False def __call__(self, val_loss: float, epoch: int = 0) -> bool: if val_loss < self.best_loss - self.min_delta: self.best_loss = val_loss self.best_epoch = epoch self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.should_stop = True return self.should_stop def get_best_epoch(self) -> int: """获取最佳 epoch(1-indexed)""" return self.best_epoch + 1 def train_with_early_stopping( model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, lr: float = 1e-4, weight_decay: float = 1e-5, epochs: int = 100, patience: int = 15, task_weights: Optional[LossWeightsBalanced] = None, class_weights: Optional[ClassWeights] = None, backbone_lr_ratio: float = 1.0, ) -> Dict: """ 带早停的完整训练流程。 Returns: Dict with keys: history, best_val_loss, best_epoch, best_state """ model = model.to(device) optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=5 ) early_stopping = EarlyStoppingBalanced(patience=patience) history = {"train": [], "val": []} best_val_loss = float("inf") best_state = None for epoch in range(epochs): # Train train_metrics = train_epoch_balanced( model, train_loader, optimizer, device, task_weights, class_weights ) # Validate val_metrics = validate_balanced( model, val_loader, device, task_weights, class_weights ) history["train"].append(train_metrics) history["val"].append(val_metrics) # Learning rate scheduling scheduler.step(val_metrics["loss"]) # Save best model if val_metrics["loss"] < best_val_loss: best_val_loss = val_metrics["loss"] best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} # Early stopping if early_stopping(val_metrics["loss"], epoch): break # Restore best model if best_state is not None: model.load_state_dict(best_state) return { "history": history, "best_val_loss": best_val_loss, "best_epoch": early_stopping.get_best_epoch(), "best_state": best_state, "epochs_trained": len(history["train"]), } def train_fixed_epochs( model: nn.Module, train_loader: DataLoader, val_loader: Optional[DataLoader], device: torch.device, lr: float = 1e-4, weight_decay: float = 1e-5, epochs: int = 50, task_weights: Optional[LossWeightsBalanced] = None, class_weights: Optional[ClassWeights] = None, use_cosine_annealing: bool = True, use_swa: bool = False, swa_start_epoch: Optional[int] = None, backbone_lr_ratio: float = 1.0, ) -> Dict: """ 固定 epoch 数的训练(不使用 early stopping)。 用于外层 CV 训练和最终训练。 Args: model: 模型 train_loader: 训练数据 val_loader: 验证数据(可选,仅用于监控) device: 设备 lr: 学习率 weight_decay: 权重衰减 epochs: 训练轮数 task_weights: 任务权重 class_weights: 类权重 use_cosine_annealing: 是否使用 CosineAnnealingLR use_swa: 是否使用 SWA swa_start_epoch: SWA 开始的 epoch(默认为 epochs * 0.75) backbone_lr_ratio: backbone 学习率相对于 head 的比例(1.0 = 统一学习率) Returns: Dict with keys: history, final_state """ model = model.to(device) optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio) if use_cosine_annealing: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) else: scheduler = None # SWA setup swa_model = None swa_scheduler = None if use_swa: from torch.optim.swa_utils import AveragedModel, SWALR swa_model = AveragedModel(model) swa_start = swa_start_epoch or int(epochs * 0.75) swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1) history = {"train": [], "val": []} for epoch in range(epochs): # Train train_metrics = train_epoch_balanced( model, train_loader, optimizer, device, task_weights, class_weights ) history["train"].append(train_metrics) # Validate (optional) if val_loader is not None: val_metrics = validate_balanced( model, val_loader, device, task_weights, class_weights ) history["val"].append(val_metrics) # Scheduler step if use_swa and epoch >= swa_start: swa_model.update_parameters(model) swa_scheduler.step() elif scheduler is not None: scheduler.step() # Finalize SWA final_state = None if use_swa and swa_model is not None: # Update batch normalization statistics torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device) final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()} else: final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} return { "history": history, "final_state": final_state, "epochs_trained": epochs, }