diff --git a/lnp_ml/modeling/trainer.py b/lnp_ml/modeling/trainer.py index 46119ff..dbcd50c 100644 --- a/lnp_ml/modeling/trainer.py +++ b/lnp_ml/modeling/trainer.py @@ -13,18 +13,18 @@ from tqdm import tqdm @dataclass class LossWeights: """各任务的损失权重""" - # size: float = 1.0 - # pdi: float = 1.0 - # ee: float = 1.0 - # delivery: float = 1.0 - # biodist: float = 1.0 - # toxic: float = 1.0 - size: float = 0.1 - pdi: float = 0.3 - ee: float = 0.3 + size: float = 1.0 + pdi: float = 1.0 + ee: float = 1.0 delivery: float = 1.0 biodist: float = 1.0 - toxic: float = 0.05 + toxic: float = 1.0 + # size: float = 0.1 + # pdi: float = 0.3 + # ee: float = 0.3 + # delivery: float = 1.0 + # biodist: float = 1.0 + # toxic: float = 0.05 def compute_multitask_loss(