Compare commits

...

2 Commits

Author SHA1 Message Date
RYDE-WORK
93a6f8654d ... 2026-01-23 17:51:08 +08:00
RYDE-WORK
a56637c8ac Add loss visualization 2026-01-23 13:40:22 +08:00
31 changed files with 4166 additions and 2068 deletions

View File

@ -14,6 +14,7 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -351,6 +352,15 @@ def main(
json.dump(result["history"], f, indent=2) json.dump(result["history"], f, indent=2)
logger.success(f"Saved pretrain history to {history_path}") logger.success(f"Saved pretrain history to {history_path}")
# 绘制 loss 曲线图
loss_plot_path = output_dir / "pretrain_loss_curves.png"
plot_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Pretrain Loss Curves (Delivery)",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success( logger.success(
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}" f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
) )

View File

@ -16,6 +16,7 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
@ -226,6 +227,15 @@ def train_fold(
with open(history_path, "w") as f: with open(history_path, "w") as f:
json.dump(history, f, indent=2) json.dump(history, f, indent=2)
# 绘制 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Pretrain Fold {fold_idx} Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return { return {
"fold_idx": fold_idx, "fold_idx": fold_idx,
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,

View File

@ -19,6 +19,7 @@ from lnp_ml.modeling.trainer import (
EarlyStopping, EarlyStopping,
LossWeights, LossWeights,
) )
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -406,6 +407,15 @@ def main(
json.dump(result["history"], f, indent=2) json.dump(result["history"], f, indent=2)
logger.success(f"Saved training history to {history_path}") logger.success(f"Saved training history to {history_path}")
# 绘制多任务 loss 曲线图
loss_plot_path = output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Multi-task Training Loss Curves",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}") logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")

View File

@ -22,6 +22,7 @@ from lnp_ml.modeling.trainer import (
EarlyStopping, EarlyStopping,
LossWeights, LossWeights,
) )
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
@ -158,6 +159,15 @@ def train_fold(
with open(history_path, "w") as f: with open(history_path, "w") as f:
json.dump(history, f, indent=2) json.dump(history, f, indent=2)
# 绘制多任务 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Fold {fold_idx} Multi-task Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return { return {
"fold_idx": fold_idx, "fold_idx": fold_idx,
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,

View File

@ -19,6 +19,12 @@ class LossWeights:
delivery: float = 1.0 delivery: float = 1.0
biodist: float = 1.0 biodist: float = 1.0
toxic: float = 1.0 toxic: float = 1.0
# size: float = 0.1
# pdi: float = 0.3
# ee: float = 0.3
# delivery: float = 1.0
# biodist: float = 1.0
# toxic: float = 0.05
def compute_multitask_loss( def compute_multitask_loss(

View File

@ -0,0 +1,284 @@
"""训练过程可视化工具:绘制 loss 曲线"""
from pathlib import Path
from typing import Dict, List, Optional, Union
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端,避免 GUI 依赖
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def plot_loss_curves(
    history: Union[Dict[str, List[Dict]], List[Dict]],
    output_path: Path,
    title: str = "Training Loss Curves",
    figsize: tuple = (12, 8),
) -> None:
    """Plot the evolution of each loss component recorded during training.

    Two history layouts are accepted:

    1. Pretrain / single-task layout::

        {"train": [{"loss": 0.1, ...}, ...], "val": [{"loss": 0.1, ...}, ...]}

    2. CV-fold layout::

        [{"epoch": 1, "train_loss": 0.1, "val_loss": 0.1, ...}, ...]

    Args:
        history: Training history in either layout above.
        output_path: Destination image file (parent directories are created).
        title: Figure title.
        figsize: Figure size in inches.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Dispatch on the detected history layout.
    if isinstance(history, list):
        # Layout 2: a flat list of per-epoch records.
        _plot_flat_history(history, target, title, figsize)
        return
    if isinstance(history, dict) and "train" in history:
        # Layout 1: separate "train" / "val" record lists.
        _plot_standard_history(history, target, title, figsize)
        return
    raise ValueError(f"Unsupported history format: {type(history)}")
def _plot_standard_history(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Plot a history in the ``{"train": [...], "val": [...]}`` layout.

    Draws the total loss in the first subplot and, when per-task keys
    (``loss_<task>``) exist, all per-task losses in a second subplot.
    """
    train_history = history.get("train", [])
    val_history = history.get("val", [])
    if not train_history:
        # Nothing recorded yet — skip instead of writing an empty figure.
        return

    epochs = list(range(1, len(train_history) + 1))
    # Val may legitimately differ in length from train (e.g. interrupted
    # run); give it its own x range so matplotlib does not raise on a
    # length mismatch between x and y.
    val_epochs = list(range(1, len(val_history) + 1))

    # Collect every loss key ("loss" itself, or "loss_<task>").
    all_loss_keys = set()
    for record in train_history + val_history:
        all_loss_keys.update(k for k in record if k.startswith("loss"))

    total_loss_key = "loss"
    task_loss_keys = sorted(k for k in all_loss_keys if k != total_loss_key)

    # One subplot for the total loss, plus one for per-task losses if any.
    n_subplots = 2 if task_loss_keys else 1
    fig, axes = plt.subplots(
        n_subplots, 1, figsize=(figsize[0], figsize[1] * n_subplots / 2)
    )
    if n_subplots == 1:
        axes = [axes]

    colors = plt.cm.tab10.colors

    # Subplot 1: total loss.
    ax = axes[0]
    train_total_loss = [r.get(total_loss_key, 0) for r in train_history]
    val_total_loss = [r.get(total_loss_key, 0) for r in val_history]
    ax.plot(epochs, train_total_loss, 'o-', label='Train Total Loss',
            color=colors[0], linewidth=2, markersize=4)
    if val_total_loss:
        ax.plot(val_epochs, val_total_loss, 's--', label='Val Total Loss',
                color=colors[1], linewidth=2, markersize=4)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(f'{title} - Total Loss', fontsize=14, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.5, len(epochs) + 0.5)

    # Subplot 2: per-task losses (only when such keys were recorded).
    if task_loss_keys:
        ax = axes[1]
        for i, key in enumerate(task_loss_keys):
            task_name = key.replace("loss_", "").upper()
            train_values = [r.get(key, 0) for r in train_history]
            val_values = [r.get(key, 0) for r in val_history]
            color = colors[i % len(colors)]
            ax.plot(epochs, train_values, 'o-', label=f'Train {task_name}',
                    color=color, alpha=0.8, linewidth=1.5, markersize=3)
            # Skip all-zero val series (task absent from the val split).
            if val_values and any(v > 0 for v in val_values):
                ax.plot(val_epochs, val_values, 's--', label=f'Val {task_name}',
                        color=color, alpha=0.5, linewidth=1.5, markersize=3)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(f'{title} - Per-Task Loss', fontsize=14, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=2)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0.5, len(epochs) + 0.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
def _plot_flat_history(
    history: List[Dict],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Plot a flat history layout (one record per epoch, CV-fold style)."""
    if not history:
        return

    epochs = [record.get("epoch", idx + 1) for idx, record in enumerate(history)]

    # Gather every key that looks loss-related across all epoch records.
    loss_keys = {k for record in history for k in record if "loss" in k.lower()}
    train_keys = sorted(k for k in loss_keys if "train" in k.lower())
    val_keys = sorted(k for k in loss_keys if "val" in k.lower())

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    palette = plt.cm.tab10.colors
    color_idx = 0

    # Train curves first (solid), then val curves (dashed, more transparent);
    # colors are assigned sequentially across both groups.
    for key_group, line_style, opacity in (
        (train_keys, 'o-', 0.9),
        (val_keys, 's--', 0.7),
    ):
        for key in key_group:
            series = [record.get(key, 0) for record in history]
            ax.plot(
                epochs,
                series,
                line_style,
                label=key.replace("_", " ").title(),
                color=palette[color_idx % len(palette)],
                linewidth=2,
                markersize=4,
                alpha=opacity,
            )
            color_idx += 1

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.5, len(epochs) + 0.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
def plot_multitask_loss_curves(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str = "Multi-task Training Loss",
    figsize: tuple = (14, 10),
) -> None:
    """Plot multi-task training losses, one subplot per task.

    The total loss gets the first subplot; each per-task loss
    (``loss_<task>`` keys in the records) gets its own subplot so tasks
    can be compared side by side.

    Args:
        history: ``{"train": [...], "val": [...]}`` training history.
        output_path: Destination image file (parent directories are created).
        title: Figure title.
        figsize: Figure size in inches.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    train_history = history.get("train", [])
    val_history = history.get("val", [])
    if not train_history:
        # Nothing recorded yet — skip instead of writing an empty figure.
        return

    epochs = list(range(1, len(train_history) + 1))
    # Val may differ in length from train (e.g. interrupted run); give it
    # its own x range so matplotlib does not raise on a length mismatch.
    val_epochs = list(range(1, len(val_history) + 1))

    # Discover task names from "loss_<task>" keys in the train records.
    task_keys = sorted(
        {key[len("loss_"):]
         for record in train_history
         for key in record
         if key.startswith("loss_")}
    )

    if not task_keys:
        # Only a total loss was recorded — the simple plot suffices.
        _plot_standard_history(history, output_path, title, figsize)
        return

    # Grid layout: total loss + one subplot per task, up to 3 columns.
    n_plots = len(task_keys) + 1
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(figsize[0], figsize[1] * n_rows / 2)
    )
    # n_plots >= 2 here, so plt.subplots always returns an ndarray of Axes;
    # .flat flattens it row-major regardless of 1-D/2-D shape.
    axes_flat = list(axes.flat)

    colors = plt.cm.tab10.colors

    # Subplot 1: total loss.
    ax = axes_flat[0]
    train_total = [r.get("loss", 0) for r in train_history]
    val_total = [r.get("loss", 0) for r in val_history]
    ax.plot(epochs, train_total, 'o-', label='Train', color=colors[0],
            linewidth=2, markersize=4)
    if val_total:
        ax.plot(val_epochs, val_total, 's--', label='Val', color=colors[1],
                linewidth=2, markersize=4)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Total Loss', fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    # One subplot per task.
    for idx, task in enumerate(task_keys):
        ax = axes_flat[idx + 1]
        key = f"loss_{task}"
        train_values = [r.get(key, 0) for r in train_history]
        val_values = [r.get(key, 0) for r in val_history]
        # Skip all-zero series (the task was never active in that split).
        if any(v > 0 for v in train_values):
            ax.plot(epochs, train_values, 'o-', label='Train',
                    color=colors[0], linewidth=2, markersize=4)
        if val_values and any(v > 0 for v in val_values):
            ax.plot(val_epochs, val_values, 's--', label='Val',
                    color=colors[1], linewidth=2, markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{task.upper()} Loss', fontweight='bold')
        ax.legend(loc='upper right')
        ax.grid(True, alpha=0.3)

    # Hide unused grid cells.
    for idx in range(n_plots, len(axes_flat)):
        axes_flat[idx].set_visible(False)

    plt.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

View File

@ -11,6 +11,6 @@
"batch_size": 32, "batch_size": 32,
"epochs": 100, "epochs": 100,
"patience": 15, "patience": 15,
"init_from_pretrain": null, "init_from_pretrain": "models/pretrain_delivery.pt",
"freeze_backbone": false "freeze_backbone": false
} }

View File

@ -2,38 +2,38 @@
"fold_results": [ "fold_results": [
{ {
"fold_idx": 0, "fold_idx": 0,
"best_val_loss": 5.7676777839660645, "best_val_loss": 0.9860520362854004,
"epochs_trained": 24, "epochs_trained": 40,
"final_train_loss": 1.4942118644714355 "final_train_loss": 0.5008097920152876
}, },
{ {
"fold_idx": 1, "fold_idx": 1,
"best_val_loss": 8.418675899505615, "best_val_loss": 2.4599782625834146,
"epochs_trained": 20, "epochs_trained": 38,
"final_train_loss": 1.4902493238449097 "final_train_loss": 0.564177993271086
}, },
{ {
"fold_idx": 2, "fold_idx": 2,
"best_val_loss": 3.5122547830854143, "best_val_loss": 0.7660132050514221,
"epochs_trained": 25, "epochs_trained": 43,
"final_train_loss": 1.7609570423762004 "final_train_loss": 0.6722757054699792
}, },
{ {
"fold_idx": 3, "fold_idx": 3,
"best_val_loss": 3.165306806564331, "best_val_loss": 1.065057098865509,
"epochs_trained": 21, "epochs_trained": 31,
"final_train_loss": 2.0073827385902403 "final_train_loss": 0.7323974437183804
}, },
{ {
"fold_idx": 4, "fold_idx": 4,
"best_val_loss": 2.996154228846232, "best_val_loss": 1.321769932905833,
"epochs_trained": 18, "epochs_trained": 36,
"final_train_loss": 1.9732873006300493 "final_train_loss": 0.5991987817817264
} }
], ],
"summary": { "summary": {
"val_loss_mean": 4.772013900393532, "val_loss_mean": 1.3197741071383158,
"val_loss_std": 2.0790222989111475 "val_loss_std": 0.5971552245392587
}, },
"config": { "config": {
"d_model": 256, "d_model": 256,
@ -48,7 +48,7 @@
"batch_size": 32, "batch_size": 32,
"epochs": 100, "epochs": 100,
"patience": 15, "patience": 15,
"init_from_pretrain": null, "init_from_pretrain": "models/pretrain_delivery.pt",
"freeze_backbone": false "freeze_backbone": false
} }
} }

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 287 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 270 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

View File

@ -2,293 +2,293 @@
"fold_results": [ "fold_results": [
{ {
"fold_idx": 0, "fold_idx": 0,
"n_samples": 95, "n_samples": 87,
"size": { "size": {
"n": 95, "n": 87,
"rmse": 0.5909209144067168, "rmse": 0.5578980261231201,
"mae": 0.376253614927593, "mae": 0.293820194814397,
"r2": 0.005712927161997228 "r2": -0.16350852977979646
}, },
"delivery": { "delivery": {
"n": 66, "n": 64,
"rmse": 1.3280577883458438, "rmse": 1.4180242556680394,
"mae": 0.5195405159964028, "mae": 0.489678907149937,
"r2": 0.03195999739694366 "r2": 0.006066588761596381
}, },
"pdi": { "pdi": {
"n": 95, "n": 87,
"accuracy": 0.6105263157894737, "accuracy": 0.7011494252873564,
"precision": 0.20350877192982456, "precision": 0.382983682983683,
"recall": 0.3333333333333333, "recall": 0.3981244671781756,
"f1": 0.25272331154684097 "f1": 0.38831483607603007
}, },
"ee": { "ee": {
"n": 95, "n": 87,
"accuracy": 0.6736842105263158, "accuracy": 0.6206896551724138,
"precision": 0.22456140350877193, "precision": 0.4700393567388641,
"recall": 0.3333333333333333, "recall": 0.5084656084656085,
"f1": 0.26834381551362685 "f1": 0.4656627032698024
}, },
"toxic": { "toxic": {
"n": 66, "n": 65,
"accuracy": 0.8939393939393939, "accuracy": 0.9846153846153847,
"precision": 0.44696969696969696, "precision": 0.9918032786885246,
"recall": 0.5, "recall": 0.9,
"f1": 0.472 "f1": 0.9403122130394859
}, },
"biodist": { "biodist": {
"n": 66, "n": 65,
"kl_divergence": 0.851655784204727, "kl_divergence": 0.20084619060054956,
"js_divergence": 0.21404831573756974 "js_divergence": 0.048292698388857934
} }
}, },
{ {
"fold_idx": 1, "fold_idx": 1,
"n_samples": 195, "n_samples": 87,
"size": { "size": {
"n": 193, "n": 87,
"rmse": 0.4425801645813746, "rmse": 0.40406802223294774,
"mae": 0.26432527161632796, "mae": 0.2981993521767101,
"r2": -0.026225211870033682 "r2": 0.09717283856858072
}, },
"delivery": { "delivery": {
"n": 123, "n": 58,
"rmse": 0.7771322048436382, "rmse": 0.545857341773909,
"mae": 0.6133777339870822, "mae": 0.3866841504832023,
"r2": -0.128644776760948 "r2": 0.39045172405838346
}, },
"pdi": { "pdi": {
"n": 195, "n": 87,
"accuracy": 0.7076923076923077, "accuracy": 0.8045977011494253,
"precision": 0.35384615384615387, "precision": 0.49282452707110247,
"recall": 0.5, "recall": 0.45567765567765567,
"f1": 0.4144144144144144 "f1": 0.4661145617667357
}, },
"ee": { "ee": {
"n": 195, "n": 87,
"accuracy": 0.4205128205128205, "accuracy": 0.7241379310344828,
"precision": 0.14017094017094017, "precision": 0.7064586357039188,
"recall": 0.3333333333333333, "recall": 0.6503496503496503,
"f1": 0.19735258724428398 "f1": 0.6654761904761903
}, },
"toxic": { "toxic": {
"n": 123, "n": 59,
"accuracy": 1.0, "accuracy": 0.9830508474576272,
"precision": 1.0, "precision": 0.9912280701754386,
"recall": 1.0, "recall": 0.8333333333333333,
"f1": 1.0 "f1": 0.8955752212389381
}, },
"biodist": { "biodist": {
"n": 123, "n": 58,
"kl_divergence": 0.9336461102028436, "kl_divergence": 0.185822128176519,
"js_divergence": 0.24870266224462317 "js_divergence": 0.049566546350752166
} }
}, },
{ {
"fold_idx": 2, "fold_idx": 2,
"n_samples": 51, "n_samples": 87,
"size": { "size": {
"n": 51, "n": 86,
"rmse": 0.6473513298834871, "rmse": 0.5861093745094258,
"mae": 0.5600235602434944, "mae": 0.35274335949919944,
"r2": -9.27515642706235 "r2": -0.3079648452189143
}, },
"delivery": { "delivery": {
"n": 44, "n": 61,
"rmse": 0.7721077356414991, "rmse": 0.5034529339588798,
"mae": 0.6167582499593581, "mae": 0.3725305872618175,
"r2": -0.4822886602727561 "r2": 0.596618413667312
}, },
"pdi": { "pdi": {
"n": 51, "n": 87,
"accuracy": 0.8823529411764706, "accuracy": 0.7701149425287356,
"precision": 0.29411764705882354, "precision": 0.7266666666666666,
"recall": 0.3333333333333333, "recall": 0.6349206349206349,
"f1": 0.3125 "f1": 0.6497584541062802
}, },
"ee": { "ee": {
"n": 51, "n": 87,
"accuracy": 0.8431372549019608, "accuracy": 0.5172413793103449,
"precision": 0.28104575163398693, "precision": 0.43155828639699606,
"recall": 0.3333333333333333, "recall": 0.40000766812361016,
"f1": 0.3049645390070922 "f1": 0.3980599647266314
}, },
"toxic": { "toxic": {
"n": 47, "n": 61,
"accuracy": 0.851063829787234, "accuracy": 1.0,
"precision": 0.425531914893617, "precision": 1.0,
"recall": 0.5, "recall": 1.0,
"f1": 0.4597701149425288 "f1": 1.0
}, },
"biodist": { "biodist": {
"n": 45, "n": 61,
"kl_divergence": 1.1049896129018548, "kl_divergence": 0.2646404257700098,
"js_divergence": 0.25485248115851133 "js_divergence": 0.07024299955112
} }
}, },
{ {
"fold_idx": 3, "fold_idx": 3,
"n_samples": 66, "n_samples": 86,
"size": { "size": {
"n": 66, "n": 86,
"rmse": 0.2407212117920812, "rmse": 0.32742961478246685,
"mae": 0.19363613562150436, "mae": 0.25193805472795355,
"r2": -0.11204941379936861 "r2": -0.09589933555096875
}, },
"delivery": { "delivery": {
"n": 62, "n": 68,
"rmse": 1.0041711455927012, "rmse": 0.7277366648519259,
"mae": 0.7132550483914993, "mae": 0.42998586144462664,
"r2": -0.63265374674746 "r2": 0.4053674615039361
}, },
"pdi": { "pdi": {
"n": 66, "n": 86,
"accuracy": 0.8484848484848485, "accuracy": 0.7906976744186046,
"precision": 0.42424242424242425, "precision": 0.7513157894736842,
"recall": 0.5, "recall": 0.6356534090909091,
"f1": 0.4590163934426229 "f1": 0.6544642857142857
}, },
"ee": { "ee": {
"n": 66, "n": 86,
"accuracy": 0.8181818181818182, "accuracy": 0.7441860465116279,
"precision": 0.27692307692307694, "precision": 0.6962905144216474,
"recall": 0.32727272727272727, "recall": 0.6327243018419488,
"f1": 0.3 "f1": 0.6557768628760061
}, },
"toxic": { "toxic": {
"n": 62, "n": 68,
"accuracy": 1.0, "accuracy": 1.0,
"precision": 1.0, "precision": 1.0,
"recall": 1.0, "recall": 1.0,
"f1": 1.0 "f1": 1.0
}, },
"biodist": { "biodist": {
"n": 62, "n": 68,
"kl_divergence": 0.9677978984139058, "kl_divergence": 0.3411994760293469,
"js_divergence": 0.2020309307244639 "js_divergence": 0.07812197338009717
} }
}, },
{ {
"fold_idx": 4, "fold_idx": 4,
"n_samples": 27, "n_samples": 87,
"size": { "size": {
"n": 27, "n": 86,
"rmse": 0.23392834445509142, "rmse": 0.28795189907825647,
"mae": 0.19066280788845485, "mae": 0.21156823080639506,
"r2": -0.2667651950955112 "r2": 0.2077479731717109
}, },
"delivery": { "delivery": {
"n": 15, "n": 59,
"rmse": 1.9603892288630869, "rmse": 0.8048179107025805,
"mae": 1.3892907698949177, "mae": 0.5188011898327682,
"r2": -0.29760739742916287 "r2": 0.24048521206149798
}, },
"pdi": { "pdi": {
"n": 27, "n": 87,
"accuracy": 0.8888888888888888, "accuracy": 0.6896551724137931,
"precision": 0.4444444444444444, "precision": 0.4101075268817204,
"recall": 0.5, "recall": 0.425,
"f1": 0.47058823529411764 "f1": 0.4174194267871083
}, },
"ee": { "ee": {
"n": 27, "n": 87,
"accuracy": 0.5925925925925926, "accuracy": 0.7011494252873564,
"precision": 0.19753086419753085, "precision": 0.6600529100529101,
"recall": 0.3333333333333333, "recall": 0.581219806763285,
"f1": 0.24806201550387597 "f1": 0.5953238953238954
}, },
"toxic": { "toxic": {
"n": 15, "n": 60,
"accuracy": 1.0, "accuracy": 0.95,
"precision": 1.0, "precision": 0.8240740740740741,
"recall": 1.0, "recall": 0.8818181818181818,
"f1": 1.0 "f1": 0.8498748957464553
}, },
"biodist": { "biodist": {
"n": 15, "n": 59,
"kl_divergence": 0.9389607012315264, "kl_divergence": 0.207699088365002,
"js_divergence": 0.2470218476598176 "js_divergence": 0.05288953180347253
} }
} }
], ],
"summary_stats": { "summary_stats": {
"size": { "size": {
"rmse_mean": 0.43110039302375025, "rmse_mean": 0.43269138734524343,
"rmse_std": 0.17179051271013462, "rmse_std": 0.12005218734930377,
"r2_mean": -1.9348966641330534, "r2_mean": -0.05249037976187758,
"r2_std": 3.6713441784129 "r2_std": 0.18417364026118202
}, },
"delivery": { "delivery": {
"rmse_mean": 1.1683716206573538, "rmse_mean": 0.7999778213910669,
"rmse_std": 0.4449374578352648, "rmse_std": 0.3285507063187239,
"r2_mean": -0.30184691676267666, "r2_mean": 0.32779788001054516,
"r2_std": 0.23809090378746706 "r2_std": 0.19664259417310184
}, },
"pdi": { "pdi": {
"accuracy_mean": 0.7875890604063979, "accuracy_mean": 0.751242983159583,
"accuracy_std": 0.11016791908756088, "accuracy_std": 0.04703609766404652,
"f1_mean": 0.3818484709395992, "f1_mean": 0.515214312890088,
"f1_std": 0.08529090446864619 "f1_std": 0.11451705845950987
}, },
"ee": { "ee": {
"accuracy_mean": 0.6696217393431015, "accuracy_mean": 0.6614808874632452,
"accuracy_std": 0.15503740047242787, "accuracy_std": 0.08343692429542732,
"f1_mean": 0.2637445914537758, "f1_mean": 0.5560599233345052,
"f1_std": 0.039213602228007696 "f1_std": 0.10638861908930271
}, },
"toxic": { "toxic": {
"accuracy_mean": 0.9490006447453256, "accuracy_mean": 0.9835332464146024,
"accuracy_std": 0.06391582554207781, "accuracy_std": 0.018265761928954075,
"f1_mean": 0.7863540229885058, "f1_mean": 0.9371524660049758,
"f1_std": 0.26169039387919035 "f1_std": 0.05874632007606866
}, },
"biodist": { "biodist": {
"kl_mean": 0.9594100213909715, "kl_mean": 0.24004146178828548,
"kl_std": 0.08240959093662605, "kl_std": 0.05720155140134621,
"js_mean": 0.23333124750499712, "js_mean": 0.05982274989485996,
"js_std": 0.021158533549255752 "js_std": 0.012080103435264247
} }
}, },
"overall": { "overall": {
"size": { "size": {
"n_samples": 432, "n_samples": 432,
"mse": 0.22604480336185886, "mse": 0.20179930058631468,
"rmse": 0.47544169291497657, "rmse": 0.44922077043065883,
"mae": 0.3084443360567093, "mae": 0.2817203010673876,
"r2": -0.2313078534105617 "r2": -0.09923811531698323
}, },
"delivery": { "delivery": {
"n_samples": 310, "n_samples": 310,
"mse": 1.0873755440675295, "mse": 0.760202623547859,
"rmse": 1.0427730069710903, "rmse": 0.8718959935381393,
"mae": 0.6513989447841361, "mae": 0.4398058238289049,
"r2": -0.09443640807387799 "r2": 0.23486100707044422
}, },
"pdi": { "pdi": {
"n_samples": 434, "n_samples": 434,
"accuracy": 0.7396313364055299, "accuracy": 0.7511520737327189,
"precision": 0.18490783410138248, "precision": 0.3287852263755878,
"recall": 0.25, "recall": 0.3184060228452752,
"f1": 0.21258278145695364 "f1": 0.3212571677885814
}, },
"ee": { "ee": {
"n_samples": 434, "n_samples": 434,
"accuracy": 0.5967741935483871, "accuracy": 0.6612903225806451,
"precision": 0.1993841416474211, "precision": 0.5836247086247086,
"recall": 0.33205128205128204, "recall": 0.5469842657342657,
"f1": 0.24915824915824913 "f1": 0.5597637622559741
}, },
"toxic": { "toxic": {
"n_samples": 313, "n_samples": 313,
"accuracy": 0.9552715654952076, "accuracy": 0.9840255591054313,
"precision": 0.4776357827476038, "precision": 0.918076923076923,
"recall": 0.5, "recall": 0.8895126612517916,
"f1": 0.48856209150326796 "f1": 0.9032337847028998
}, },
"biodist": { "biodist": {
"n_samples": 311, "n_samples": 311,
"kl_divergence": 0.9481034280166569, "kl_divergence": 0.24254521665201,
"js_divergence": 0.23285280825310384 "js_divergence": 0.06022985409160514
} }
} }
} }

Binary file not shown.

View File

@ -1,205 +1,293 @@
{ {
"train": [ "train": [
{ {
"loss": 0.7730368412685099, "loss": 0.7867247516051271,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.658895703010919, "loss": 0.6515523084370274,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.6059015260392299, "loss": 0.5990842743185651,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5744731174349416, "loss": 0.5633418128920326,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5452056020458733, "loss": 0.5453761521296815,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5138543470936083, "loss": 0.49953126002250825,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4885380559178135, "loss": 0.49147369265204843,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.47587182296687974, "loss": 0.4659397622863399,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4671051038255316, "loss": 0.4653009635305819,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.46794115915756107, "loss": 0.4380375076610923,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4293930456997915, "loss": 0.4258159104875806,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.42624105651716415, "loss": 0.4144523660948226,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4131358770446828, "loss": 0.4008358244841981,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3946074267790835, "loss": 0.40240038808127093,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3898155013755344, "loss": 0.38176763174141226,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.37861797005733383, "loss": 0.37277743237904,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3775682858392304, "loss": 0.3573742728176747,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3800349080262064, "loss": 0.3491767022517619,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.36302345173031675, "loss": 0.3499675623860557,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3429561740842766, "loss": 0.35079578643841286,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3445638883004898, "loss": 0.3414057292594471,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.318970229203733, "loss": 0.300529963052257,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.30179278279904437, "loss": 0.2961940902990875,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.2887343142006437, "loss": 0.2907693515383844,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.29240367556855545, "loss": 0.2809350616734551,
"n_samples": 6783
},
{
"loss": 0.28143580470326973,
"n_samples": 6783
},
{
"loss": 0.2664423378391215,
"n_samples": 6783
},
{
"loss": 0.2745858784487654,
"n_samples": 6783
},
{
"loss": 0.26682337215652197,
"n_samples": 6783
},
{
"loss": 0.2681302405486289,
"n_samples": 6783
},
{
"loss": 0.26258669999889017,
"n_samples": 6783
},
{
"loss": 0.2608744821883436,
"n_samples": 6783
},
{
"loss": 0.239722755447208,
"n_samples": 6783
},
{
"loss": 0.24175641130912484,
"n_samples": 6783
},
{
"loss": 0.23785491213674798,
"n_samples": 6783
},
{
"loss": 0.23117999019839675,
"n_samples": 6783 "n_samples": 6783
} }
], ],
"val": [ "val": [
{ {
"loss": 0.7350345371841441, "loss": 0.7379055186813953,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.7165568811318536, "loss": 0.7269540305477178,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.7251406249862214, "loss": 0.6927794775152518,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6836505264587159, "loss": 0.6652627856533758,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6747132955771933, "loss": 0.6721626692594103,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6691136244936912, "loss": 0.6660812889787394,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6337480902323249, "loss": 0.6329412659009298,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6600317959527934, "loss": 0.6395346636332554,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6439923948855346, "loss": 0.6228037932749914,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.643800035575267, "loss": 0.627341245329581,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6181512585221839, "loss": 0.6399482499704272,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6442458634939151, "loss": 0.6136556721283145,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6344759362359862, "loss": 0.6253821217484764,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6501405371457472, "loss": 0.6535878175511408,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6098835162990152, "loss": 0.6123772015635566,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6366627322138894, "loss": 0.648116178003258,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6171610150646417, "loss": 0.6141229696318092,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6358801012273748, "loss": 0.6307853255273552,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6239976831059871, "loss": 0.6422428293329848,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6683828232827201, "loss": 0.6552245357949421,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6655785786478143, "loss": 0.6614342853503823,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6152775046503088, "loss": 0.6153881246378465,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6202247662153858, "loss": 0.6298057977526632,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.648199727435189, "loss": 0.6117005377321011,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6473217075085124, "loss": 0.618167482426702,
"n_samples": 2907
},
{
"loss": 0.6103124622220963,
"n_samples": 2907
},
{
"loss": 0.6203263100988888,
"n_samples": 2907
},
{
"loss": 0.617202088939065,
"n_samples": 2907
},
{
"loss": 0.6263537556599373,
"n_samples": 2907
},
{
"loss": 0.6461186489679168,
"n_samples": 2907
},
{
"loss": 0.6289772454163526,
"n_samples": 2907
},
{
"loss": 0.6347806630652919,
"n_samples": 2907
},
{
"loss": 0.6358688624302136,
"n_samples": 2907
},
{
"loss": 0.646931814478975,
"n_samples": 2907
},
{
"loss": 0.6245040218978556,
"n_samples": 2907
},
{
"loss": 0.6267098000482632,
"n_samples": 2907 "n_samples": 2907
} }
] ]

Binary file not shown.

After

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 277 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 280 KiB