Compare commits

..

No commits in common. "93a6f8654d5d70984d9ff0da66fe3f5b4a7f9975" and "871afc5988504a2750d67f6ebf5b601a14efbc7c" have entirely different histories.

31 changed files with 2068 additions and 4166 deletions

View File

@ -14,7 +14,6 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -352,15 +351,6 @@ def main(
json.dump(result["history"], f, indent=2) json.dump(result["history"], f, indent=2)
logger.success(f"Saved pretrain history to {history_path}") logger.success(f"Saved pretrain history to {history_path}")
# 绘制 loss 曲线图
loss_plot_path = output_dir / "pretrain_loss_curves.png"
plot_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Pretrain Loss Curves (Delivery)",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success( logger.success(
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}" f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
) )

View File

@ -16,7 +16,6 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
@ -227,15 +226,6 @@ def train_fold(
with open(history_path, "w") as f: with open(history_path, "w") as f:
json.dump(history, f, indent=2) json.dump(history, f, indent=2)
# 绘制 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Pretrain Fold {fold_idx} Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return { return {
"fold_idx": fold_idx, "fold_idx": fold_idx,
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,

View File

@ -19,7 +19,6 @@ from lnp_ml.modeling.trainer import (
EarlyStopping, EarlyStopping,
LossWeights, LossWeights,
) )
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -407,15 +406,6 @@ def main(
json.dump(result["history"], f, indent=2) json.dump(result["history"], f, indent=2)
logger.success(f"Saved training history to {history_path}") logger.success(f"Saved training history to {history_path}")
# 绘制多任务 loss 曲线图
loss_plot_path = output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Multi-task Training Loss Curves",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}") logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")

View File

@ -22,7 +22,6 @@ from lnp_ml.modeling.trainer import (
EarlyStopping, EarlyStopping,
LossWeights, LossWeights,
) )
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径 # MPNN ensemble 默认路径
@ -159,15 +158,6 @@ def train_fold(
with open(history_path, "w") as f: with open(history_path, "w") as f:
json.dump(history, f, indent=2) json.dump(history, f, indent=2)
# 绘制多任务 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Fold {fold_idx} Multi-task Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return { return {
"fold_idx": fold_idx, "fold_idx": fold_idx,
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,

View File

@ -19,12 +19,6 @@ class LossWeights:
delivery: float = 1.0 delivery: float = 1.0
biodist: float = 1.0 biodist: float = 1.0
toxic: float = 1.0 toxic: float = 1.0
# size: float = 0.1
# pdi: float = 0.3
# ee: float = 0.3
# delivery: float = 1.0
# biodist: float = 1.0
# toxic: float = 0.05
def compute_multitask_loss( def compute_multitask_loss(

View File

@ -1,284 +0,0 @@
"""训练过程可视化工具:绘制 loss 曲线"""
from pathlib import Path
from typing import Dict, List, Optional, Union
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端,避免 GUI 依赖
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def plot_loss_curves(
    history: Union[Dict[str, List[Dict]], List[Dict]],
    output_path: Path,
    title: str = "Training Loss Curves",
    figsize: tuple = (12, 8),
) -> None:
    """Plot the evolution of each loss component over training.

    Two history layouts are accepted:

    1. Pretrain (single-task) layout::

        {"train": [{"loss": 0.1, ...}, ...], "val": [{"loss": 0.1, ...}, ...]}

    2. CV-fold layout::

        [{"epoch": 1, "train_loss": 0.1, "val_loss": 0.1, ...}, ...]

    Args:
        history: Training history records in one of the layouts above.
        output_path: Destination image file; parent directories are created.
        title: Figure title.
        figsize: Figure size in inches.

    Raises:
        ValueError: If ``history`` matches neither supported layout.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Dispatch on the detected history layout; anything else is an error.
    if isinstance(history, dict) and "train" in history:
        _plot_standard_history(history, target, title, figsize)
        return
    if isinstance(history, list):
        _plot_flat_history(history, target, title, figsize)
        return
    raise ValueError(f"Unsupported history format: {type(history)}")
def _plot_standard_history(
history: Dict[str, List[Dict]],
output_path: Path,
title: str,
figsize: tuple,
) -> None:
"""绘制标准格式的 historytrain/val 分开)"""
train_history = history.get("train", [])
val_history = history.get("val", [])
if not train_history:
return
epochs = list(range(1, len(train_history) + 1))
# 收集所有 loss 键
all_loss_keys = set()
for record in train_history + val_history:
for key in record.keys():
if key.startswith("loss") or key == "loss":
all_loss_keys.add(key)
# 分离总 loss 和各任务 loss
total_loss_key = "loss"
task_loss_keys = sorted([k for k in all_loss_keys if k != "loss"])
# 创建子图
n_subplots = 1 + (1 if task_loss_keys else 0)
fig, axes = plt.subplots(n_subplots, 1, figsize=(figsize[0], figsize[1] * n_subplots / 2))
if n_subplots == 1:
axes = [axes]
# 颜色配置
colors = plt.cm.tab10.colors
# 子图1总 loss
ax = axes[0]
train_total_loss = [r.get(total_loss_key, 0) for r in train_history]
val_total_loss = [r.get(total_loss_key, 0) for r in val_history]
ax.plot(epochs, train_total_loss, 'o-', label='Train Total Loss', color=colors[0], linewidth=2, markersize=4)
if val_total_loss:
ax.plot(epochs, val_total_loss, 's--', label='Val Total Loss', color=colors[1], linewidth=2, markersize=4)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title(f'{title} - Total Loss', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(0.5, len(epochs) + 0.5)
# 子图2各任务 loss如果有
if task_loss_keys:
ax = axes[1]
for i, key in enumerate(task_loss_keys):
task_name = key.replace("loss_", "").upper()
train_values = [r.get(key, 0) for r in train_history]
val_values = [r.get(key, 0) for r in val_history]
color = colors[i % len(colors)]
ax.plot(epochs, train_values, 'o-', label=f'Train {task_name}', color=color, alpha=0.8, linewidth=1.5, markersize=3)
if val_values and any(v > 0 for v in val_values):
ax.plot(epochs, val_values, 's--', label=f'Val {task_name}', color=color, alpha=0.5, linewidth=1.5, markersize=3)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title(f'{title} - Per-Task Loss', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=9, ncol=2)
ax.grid(True, alpha=0.3)
ax.set_xlim(0.5, len(epochs) + 0.5)
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
plt.close(fig)
def _plot_flat_history(
history: List[Dict],
output_path: Path,
title: str,
figsize: tuple,
) -> None:
"""绘制扁平格式的 historyCV fold 格式)"""
if not history:
return
epochs = [r.get("epoch", i + 1) for i, r in enumerate(history)]
# 收集所有 loss 相关的键
loss_keys = set()
for record in history:
for key in record.keys():
if "loss" in key.lower():
loss_keys.add(key)
# 分类
train_keys = sorted([k for k in loss_keys if "train" in k.lower()])
val_keys = sorted([k for k in loss_keys if "val" in k.lower()])
# 创建子图
fig, ax = plt.subplots(1, 1, figsize=figsize)
colors = plt.cm.tab10.colors
color_idx = 0
# 绘制训练 loss
for key in train_keys:
values = [r.get(key, 0) for r in history]
label = key.replace("_", " ").title()
ax.plot(epochs, values, 'o-', label=label, color=colors[color_idx % len(colors)],
linewidth=2, markersize=4, alpha=0.9)
color_idx += 1
# 绘制验证 loss
for key in val_keys:
values = [r.get(key, 0) for r in history]
label = key.replace("_", " ").title()
ax.plot(epochs, values, 's--', label=label, color=colors[color_idx % len(colors)],
linewidth=2, markersize=4, alpha=0.7)
color_idx += 1
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title(title, fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(0.5, len(epochs) + 0.5)
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
plt.close(fig)
def plot_multitask_loss_curves(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str = "Multi-task Training Loss",
    figsize: tuple = (14, 10),
) -> None:
    """
    Plot loss curves for multi-task training.

    Each task's loss (``loss_<task>`` keys in the records) is drawn in its
    own subplot, alongside a subplot for the total loss, to make per-task
    comparison easy.

    Args:
        history: Training history in ``{"train": [...], "val": [...]}`` form.
        output_path: Destination image file; parent directories are created.
        title: Figure suptitle.
        figsize: Base figure size; height scales with the number of grid rows.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    train_history = history.get("train", [])
    val_history = history.get("val", [])
    if not train_history:
        # Nothing recorded -> nothing to plot.
        return
    epochs = list(range(1, len(train_history) + 1))
    # Extract the task names from all "loss_*" keys in the train records.
    task_keys = set()
    for record in train_history:
        for key in record.keys():
            if key.startswith("loss_"):
                task_name = key.replace("loss_", "")
                task_keys.add(task_name)
    task_keys = sorted(task_keys)
    # Work out the subplot grid layout.
    n_tasks = len(task_keys)
    if n_tasks == 0:
        # Only the total loss is present; fall back to the simple plot.
        _plot_standard_history(history, output_path, title, figsize)
        return
    # One subplot per task, plus one for the total loss.
    n_plots = n_tasks + 1
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(figsize[0], figsize[1] * n_rows / 2))
    # Normalize `axes` to a nested (row-major) structure regardless of the
    # grid shape returned by plt.subplots, then flatten it for indexing.
    if n_plots == 1:
        axes = [[axes]]
    elif n_rows == 1:
        axes = [axes]
    axes_flat = [ax for row in axes for ax in (row if hasattr(row, '__iter__') else [row])]
    colors = plt.cm.tab10.colors
    # Subplot 1: total loss.
    ax = axes_flat[0]
    train_total = [r.get("loss", 0) for r in train_history]
    val_total = [r.get("loss", 0) for r in val_history]
    ax.plot(epochs, train_total, 'o-', label='Train', color=colors[0], linewidth=2, markersize=4)
    if val_total:
        ax.plot(epochs, val_total, 's--', label='Val', color=colors[1], linewidth=2, markersize=4)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Total Loss', fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    # One subplot per task.
    for idx, task in enumerate(task_keys):
        ax = axes_flat[idx + 1]
        key = f"loss_{task}"
        train_values = [r.get(key, 0) for r in train_history]
        val_values = [r.get(key, 0) for r in val_history]
        # Only draw series that carry non-zero data.
        if any(v > 0 for v in train_values):
            ax.plot(epochs, train_values, 'o-', label='Train', color=colors[0], linewidth=2, markersize=4)
        if val_values and any(v > 0 for v in val_values):
            ax.plot(epochs, val_values, 's--', label='Val', color=colors[1], linewidth=2, markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{task.upper()} Loss', fontweight='bold')
        ax.legend(loc='upper right')
        ax.grid(True, alpha=0.3)
    # Hide any unused axes in the grid.
    for idx in range(n_plots, len(axes_flat)):
        axes_flat[idx].set_visible(False)
    plt.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

View File

@ -11,6 +11,6 @@
"batch_size": 32, "batch_size": 32,
"epochs": 100, "epochs": 100,
"patience": 15, "patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt", "init_from_pretrain": null,
"freeze_backbone": false "freeze_backbone": false
} }

View File

@ -2,38 +2,38 @@
"fold_results": [ "fold_results": [
{ {
"fold_idx": 0, "fold_idx": 0,
"best_val_loss": 0.9860520362854004, "best_val_loss": 5.7676777839660645,
"epochs_trained": 40, "epochs_trained": 24,
"final_train_loss": 0.5008097920152876 "final_train_loss": 1.4942118644714355
}, },
{ {
"fold_idx": 1, "fold_idx": 1,
"best_val_loss": 2.4599782625834146, "best_val_loss": 8.418675899505615,
"epochs_trained": 38, "epochs_trained": 20,
"final_train_loss": 0.564177993271086 "final_train_loss": 1.4902493238449097
}, },
{ {
"fold_idx": 2, "fold_idx": 2,
"best_val_loss": 0.7660132050514221, "best_val_loss": 3.5122547830854143,
"epochs_trained": 43, "epochs_trained": 25,
"final_train_loss": 0.6722757054699792 "final_train_loss": 1.7609570423762004
}, },
{ {
"fold_idx": 3, "fold_idx": 3,
"best_val_loss": 1.065057098865509, "best_val_loss": 3.165306806564331,
"epochs_trained": 31, "epochs_trained": 21,
"final_train_loss": 0.7323974437183804 "final_train_loss": 2.0073827385902403
}, },
{ {
"fold_idx": 4, "fold_idx": 4,
"best_val_loss": 1.321769932905833, "best_val_loss": 2.996154228846232,
"epochs_trained": 36, "epochs_trained": 18,
"final_train_loss": 0.5991987817817264 "final_train_loss": 1.9732873006300493
} }
], ],
"summary": { "summary": {
"val_loss_mean": 1.3197741071383158, "val_loss_mean": 4.772013900393532,
"val_loss_std": 0.5971552245392587 "val_loss_std": 2.0790222989111475
}, },
"config": { "config": {
"d_model": 256, "d_model": 256,
@ -48,7 +48,7 @@
"batch_size": 32, "batch_size": 32,
"epochs": 100, "epochs": 100,
"patience": 15, "patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt", "init_from_pretrain": null,
"freeze_backbone": false "freeze_backbone": false
} }
} }

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 287 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 270 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

View File

@ -2,293 +2,293 @@
"fold_results": [ "fold_results": [
{ {
"fold_idx": 0, "fold_idx": 0,
"n_samples": 87, "n_samples": 95,
"size": { "size": {
"n": 87, "n": 95,
"rmse": 0.5578980261231201, "rmse": 0.5909209144067168,
"mae": 0.293820194814397, "mae": 0.376253614927593,
"r2": -0.16350852977979646 "r2": 0.005712927161997228
}, },
"delivery": { "delivery": {
"n": 64, "n": 66,
"rmse": 1.4180242556680394, "rmse": 1.3280577883458438,
"mae": 0.489678907149937, "mae": 0.5195405159964028,
"r2": 0.006066588761596381 "r2": 0.03195999739694366
}, },
"pdi": { "pdi": {
"n": 87, "n": 95,
"accuracy": 0.7011494252873564, "accuracy": 0.6105263157894737,
"precision": 0.382983682983683, "precision": 0.20350877192982456,
"recall": 0.3981244671781756, "recall": 0.3333333333333333,
"f1": 0.38831483607603007 "f1": 0.25272331154684097
}, },
"ee": { "ee": {
"n": 87, "n": 95,
"accuracy": 0.6206896551724138, "accuracy": 0.6736842105263158,
"precision": 0.4700393567388641, "precision": 0.22456140350877193,
"recall": 0.5084656084656085, "recall": 0.3333333333333333,
"f1": 0.4656627032698024 "f1": 0.26834381551362685
}, },
"toxic": { "toxic": {
"n": 65, "n": 66,
"accuracy": 0.9846153846153847, "accuracy": 0.8939393939393939,
"precision": 0.9918032786885246, "precision": 0.44696969696969696,
"recall": 0.9, "recall": 0.5,
"f1": 0.9403122130394859 "f1": 0.472
}, },
"biodist": { "biodist": {
"n": 65, "n": 66,
"kl_divergence": 0.20084619060054956, "kl_divergence": 0.851655784204727,
"js_divergence": 0.048292698388857934 "js_divergence": 0.21404831573756974
} }
}, },
{ {
"fold_idx": 1, "fold_idx": 1,
"n_samples": 87, "n_samples": 195,
"size": { "size": {
"n": 87, "n": 193,
"rmse": 0.40406802223294774, "rmse": 0.4425801645813746,
"mae": 0.2981993521767101, "mae": 0.26432527161632796,
"r2": 0.09717283856858072 "r2": -0.026225211870033682
}, },
"delivery": { "delivery": {
"n": 58, "n": 123,
"rmse": 0.545857341773909, "rmse": 0.7771322048436382,
"mae": 0.3866841504832023, "mae": 0.6133777339870822,
"r2": 0.39045172405838346 "r2": -0.128644776760948
}, },
"pdi": { "pdi": {
"n": 87, "n": 195,
"accuracy": 0.8045977011494253, "accuracy": 0.7076923076923077,
"precision": 0.49282452707110247, "precision": 0.35384615384615387,
"recall": 0.45567765567765567, "recall": 0.5,
"f1": 0.4661145617667357 "f1": 0.4144144144144144
}, },
"ee": { "ee": {
"n": 87, "n": 195,
"accuracy": 0.7241379310344828, "accuracy": 0.4205128205128205,
"precision": 0.7064586357039188, "precision": 0.14017094017094017,
"recall": 0.6503496503496503, "recall": 0.3333333333333333,
"f1": 0.6654761904761903 "f1": 0.19735258724428398
}, },
"toxic": { "toxic": {
"n": 59, "n": 123,
"accuracy": 0.9830508474576272, "accuracy": 1.0,
"precision": 0.9912280701754386, "precision": 1.0,
"recall": 0.8333333333333333, "recall": 1.0,
"f1": 0.8955752212389381 "f1": 1.0
}, },
"biodist": { "biodist": {
"n": 58, "n": 123,
"kl_divergence": 0.185822128176519, "kl_divergence": 0.9336461102028436,
"js_divergence": 0.049566546350752166 "js_divergence": 0.24870266224462317
} }
}, },
{ {
"fold_idx": 2, "fold_idx": 2,
"n_samples": 87, "n_samples": 51,
"size": { "size": {
"n": 86, "n": 51,
"rmse": 0.5861093745094258, "rmse": 0.6473513298834871,
"mae": 0.35274335949919944, "mae": 0.5600235602434944,
"r2": -0.3079648452189143 "r2": -9.27515642706235
}, },
"delivery": { "delivery": {
"n": 61, "n": 44,
"rmse": 0.5034529339588798, "rmse": 0.7721077356414991,
"mae": 0.3725305872618175, "mae": 0.6167582499593581,
"r2": 0.596618413667312 "r2": -0.4822886602727561
}, },
"pdi": { "pdi": {
"n": 87, "n": 51,
"accuracy": 0.7701149425287356, "accuracy": 0.8823529411764706,
"precision": 0.7266666666666666, "precision": 0.29411764705882354,
"recall": 0.6349206349206349, "recall": 0.3333333333333333,
"f1": 0.6497584541062802 "f1": 0.3125
}, },
"ee": { "ee": {
"n": 87, "n": 51,
"accuracy": 0.5172413793103449, "accuracy": 0.8431372549019608,
"precision": 0.43155828639699606, "precision": 0.28104575163398693,
"recall": 0.40000766812361016, "recall": 0.3333333333333333,
"f1": 0.3980599647266314 "f1": 0.3049645390070922
}, },
"toxic": { "toxic": {
"n": 61, "n": 47,
"accuracy": 1.0, "accuracy": 0.851063829787234,
"precision": 1.0, "precision": 0.425531914893617,
"recall": 1.0, "recall": 0.5,
"f1": 1.0 "f1": 0.4597701149425288
}, },
"biodist": { "biodist": {
"n": 61, "n": 45,
"kl_divergence": 0.2646404257700098, "kl_divergence": 1.1049896129018548,
"js_divergence": 0.07024299955112 "js_divergence": 0.25485248115851133
} }
}, },
{ {
"fold_idx": 3, "fold_idx": 3,
"n_samples": 86, "n_samples": 66,
"size": { "size": {
"n": 86, "n": 66,
"rmse": 0.32742961478246685, "rmse": 0.2407212117920812,
"mae": 0.25193805472795355, "mae": 0.19363613562150436,
"r2": -0.09589933555096875 "r2": -0.11204941379936861
}, },
"delivery": { "delivery": {
"n": 68, "n": 62,
"rmse": 0.7277366648519259, "rmse": 1.0041711455927012,
"mae": 0.42998586144462664, "mae": 0.7132550483914993,
"r2": 0.4053674615039361 "r2": -0.63265374674746
}, },
"pdi": { "pdi": {
"n": 86, "n": 66,
"accuracy": 0.7906976744186046, "accuracy": 0.8484848484848485,
"precision": 0.7513157894736842, "precision": 0.42424242424242425,
"recall": 0.6356534090909091, "recall": 0.5,
"f1": 0.6544642857142857 "f1": 0.4590163934426229
}, },
"ee": { "ee": {
"n": 86, "n": 66,
"accuracy": 0.7441860465116279, "accuracy": 0.8181818181818182,
"precision": 0.6962905144216474, "precision": 0.27692307692307694,
"recall": 0.6327243018419488, "recall": 0.32727272727272727,
"f1": 0.6557768628760061 "f1": 0.3
}, },
"toxic": { "toxic": {
"n": 68, "n": 62,
"accuracy": 1.0, "accuracy": 1.0,
"precision": 1.0, "precision": 1.0,
"recall": 1.0, "recall": 1.0,
"f1": 1.0 "f1": 1.0
}, },
"biodist": { "biodist": {
"n": 68, "n": 62,
"kl_divergence": 0.3411994760293469, "kl_divergence": 0.9677978984139058,
"js_divergence": 0.07812197338009717 "js_divergence": 0.2020309307244639
} }
}, },
{ {
"fold_idx": 4, "fold_idx": 4,
"n_samples": 87, "n_samples": 27,
"size": { "size": {
"n": 86, "n": 27,
"rmse": 0.28795189907825647, "rmse": 0.23392834445509142,
"mae": 0.21156823080639506, "mae": 0.19066280788845485,
"r2": 0.2077479731717109 "r2": -0.2667651950955112
}, },
"delivery": { "delivery": {
"n": 59, "n": 15,
"rmse": 0.8048179107025805, "rmse": 1.9603892288630869,
"mae": 0.5188011898327682, "mae": 1.3892907698949177,
"r2": 0.24048521206149798 "r2": -0.29760739742916287
}, },
"pdi": { "pdi": {
"n": 87, "n": 27,
"accuracy": 0.6896551724137931, "accuracy": 0.8888888888888888,
"precision": 0.4101075268817204, "precision": 0.4444444444444444,
"recall": 0.425, "recall": 0.5,
"f1": 0.4174194267871083 "f1": 0.47058823529411764
}, },
"ee": { "ee": {
"n": 87, "n": 27,
"accuracy": 0.7011494252873564, "accuracy": 0.5925925925925926,
"precision": 0.6600529100529101, "precision": 0.19753086419753085,
"recall": 0.581219806763285, "recall": 0.3333333333333333,
"f1": 0.5953238953238954 "f1": 0.24806201550387597
}, },
"toxic": { "toxic": {
"n": 60, "n": 15,
"accuracy": 0.95, "accuracy": 1.0,
"precision": 0.8240740740740741, "precision": 1.0,
"recall": 0.8818181818181818, "recall": 1.0,
"f1": 0.8498748957464553 "f1": 1.0
}, },
"biodist": { "biodist": {
"n": 59, "n": 15,
"kl_divergence": 0.207699088365002, "kl_divergence": 0.9389607012315264,
"js_divergence": 0.05288953180347253 "js_divergence": 0.2470218476598176
} }
} }
], ],
"summary_stats": { "summary_stats": {
"size": { "size": {
"rmse_mean": 0.43269138734524343, "rmse_mean": 0.43110039302375025,
"rmse_std": 0.12005218734930377, "rmse_std": 0.17179051271013462,
"r2_mean": -0.05249037976187758, "r2_mean": -1.9348966641330534,
"r2_std": 0.18417364026118202 "r2_std": 3.6713441784129
}, },
"delivery": { "delivery": {
"rmse_mean": 0.7999778213910669, "rmse_mean": 1.1683716206573538,
"rmse_std": 0.3285507063187239, "rmse_std": 0.4449374578352648,
"r2_mean": 0.32779788001054516, "r2_mean": -0.30184691676267666,
"r2_std": 0.19664259417310184 "r2_std": 0.23809090378746706
}, },
"pdi": { "pdi": {
"accuracy_mean": 0.751242983159583, "accuracy_mean": 0.7875890604063979,
"accuracy_std": 0.04703609766404652, "accuracy_std": 0.11016791908756088,
"f1_mean": 0.515214312890088, "f1_mean": 0.3818484709395992,
"f1_std": 0.11451705845950987 "f1_std": 0.08529090446864619
}, },
"ee": { "ee": {
"accuracy_mean": 0.6614808874632452, "accuracy_mean": 0.6696217393431015,
"accuracy_std": 0.08343692429542732, "accuracy_std": 0.15503740047242787,
"f1_mean": 0.5560599233345052, "f1_mean": 0.2637445914537758,
"f1_std": 0.10638861908930271 "f1_std": 0.039213602228007696
}, },
"toxic": { "toxic": {
"accuracy_mean": 0.9835332464146024, "accuracy_mean": 0.9490006447453256,
"accuracy_std": 0.018265761928954075, "accuracy_std": 0.06391582554207781,
"f1_mean": 0.9371524660049758, "f1_mean": 0.7863540229885058,
"f1_std": 0.05874632007606866 "f1_std": 0.26169039387919035
}, },
"biodist": { "biodist": {
"kl_mean": 0.24004146178828548, "kl_mean": 0.9594100213909715,
"kl_std": 0.05720155140134621, "kl_std": 0.08240959093662605,
"js_mean": 0.05982274989485996, "js_mean": 0.23333124750499712,
"js_std": 0.012080103435264247 "js_std": 0.021158533549255752
} }
}, },
"overall": { "overall": {
"size": { "size": {
"n_samples": 432, "n_samples": 432,
"mse": 0.20179930058631468, "mse": 0.22604480336185886,
"rmse": 0.44922077043065883, "rmse": 0.47544169291497657,
"mae": 0.2817203010673876, "mae": 0.3084443360567093,
"r2": -0.09923811531698323 "r2": -0.2313078534105617
}, },
"delivery": { "delivery": {
"n_samples": 310, "n_samples": 310,
"mse": 0.760202623547859, "mse": 1.0873755440675295,
"rmse": 0.8718959935381393, "rmse": 1.0427730069710903,
"mae": 0.4398058238289049, "mae": 0.6513989447841361,
"r2": 0.23486100707044422 "r2": -0.09443640807387799
}, },
"pdi": { "pdi": {
"n_samples": 434, "n_samples": 434,
"accuracy": 0.7511520737327189, "accuracy": 0.7396313364055299,
"precision": 0.3287852263755878, "precision": 0.18490783410138248,
"recall": 0.3184060228452752, "recall": 0.25,
"f1": 0.3212571677885814 "f1": 0.21258278145695364
}, },
"ee": { "ee": {
"n_samples": 434, "n_samples": 434,
"accuracy": 0.6612903225806451, "accuracy": 0.5967741935483871,
"precision": 0.5836247086247086, "precision": 0.1993841416474211,
"recall": 0.5469842657342657, "recall": 0.33205128205128204,
"f1": 0.5597637622559741 "f1": 0.24915824915824913
}, },
"toxic": { "toxic": {
"n_samples": 313, "n_samples": 313,
"accuracy": 0.9840255591054313, "accuracy": 0.9552715654952076,
"precision": 0.918076923076923, "precision": 0.4776357827476038,
"recall": 0.8895126612517916, "recall": 0.5,
"f1": 0.9032337847028998 "f1": 0.48856209150326796
}, },
"biodist": { "biodist": {
"n_samples": 311, "n_samples": 311,
"kl_divergence": 0.24254521665201, "kl_divergence": 0.9481034280166569,
"js_divergence": 0.06022985409160514 "js_divergence": 0.23285280825310384
} }
} }
} }

Binary file not shown.

View File

@ -1,293 +1,205 @@
{ {
"train": [ "train": [
{ {
"loss": 0.7867247516051271, "loss": 0.7730368412685099,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.6515523084370274, "loss": 0.658895703010919,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5990842743185651, "loss": 0.6059015260392299,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5633418128920326, "loss": 0.5744731174349416,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.5453761521296815, "loss": 0.5452056020458733,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.49953126002250825, "loss": 0.5138543470936083,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.49147369265204843, "loss": 0.4885380559178135,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4659397622863399, "loss": 0.47587182296687974,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4653009635305819, "loss": 0.4671051038255316,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4380375076610923, "loss": 0.46794115915756107,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4258159104875806, "loss": 0.4293930456997915,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4144523660948226, "loss": 0.42624105651716415,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.4008358244841981, "loss": 0.4131358770446828,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.40240038808127093, "loss": 0.3946074267790835,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.38176763174141226, "loss": 0.3898155013755344,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.37277743237904, "loss": 0.37861797005733383,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3573742728176747, "loss": 0.3775682858392304,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3491767022517619, "loss": 0.3800349080262064,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3499675623860557, "loss": 0.36302345173031675,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.35079578643841286, "loss": 0.3429561740842766,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.3414057292594471, "loss": 0.3445638883004898,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.300529963052257, "loss": 0.318970229203733,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.2961940902990875, "loss": 0.30179278279904437,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.2907693515383844, "loss": 0.2887343142006437,
"n_samples": 6783 "n_samples": 6783
}, },
{ {
"loss": 0.2809350616734551, "loss": 0.29240367556855545,
"n_samples": 6783
},
{
"loss": 0.28143580470326973,
"n_samples": 6783
},
{
"loss": 0.2664423378391215,
"n_samples": 6783
},
{
"loss": 0.2745858784487654,
"n_samples": 6783
},
{
"loss": 0.26682337215652197,
"n_samples": 6783
},
{
"loss": 0.2681302405486289,
"n_samples": 6783
},
{
"loss": 0.26258669999889017,
"n_samples": 6783
},
{
"loss": 0.2608744821883436,
"n_samples": 6783
},
{
"loss": 0.239722755447208,
"n_samples": 6783
},
{
"loss": 0.24175641130912484,
"n_samples": 6783
},
{
"loss": 0.23785491213674798,
"n_samples": 6783
},
{
"loss": 0.23117999019839675,
"n_samples": 6783 "n_samples": 6783
} }
], ],
"val": [ "val": [
{ {
"loss": 0.7379055186813953, "loss": 0.7350345371841441,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.7269540305477178, "loss": 0.7165568811318536,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6927794775152518, "loss": 0.7251406249862214,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6652627856533758, "loss": 0.6836505264587159,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6721626692594103, "loss": 0.6747132955771933,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6660812889787394, "loss": 0.6691136244936912,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6329412659009298, "loss": 0.6337480902323249,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6395346636332554, "loss": 0.6600317959527934,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6228037932749914, "loss": 0.6439923948855346,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.627341245329581, "loss": 0.643800035575267,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6399482499704272, "loss": 0.6181512585221839,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6136556721283145, "loss": 0.6442458634939151,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6253821217484764, "loss": 0.6344759362359862,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6535878175511408, "loss": 0.6501405371457472,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6123772015635566, "loss": 0.6098835162990152,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.648116178003258, "loss": 0.6366627322138894,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6141229696318092, "loss": 0.6171610150646417,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6307853255273552, "loss": 0.6358801012273748,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6422428293329848, "loss": 0.6239976831059871,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6552245357949421, "loss": 0.6683828232827201,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6614342853503823, "loss": 0.6655785786478143,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6153881246378465, "loss": 0.6152775046503088,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6298057977526632, "loss": 0.6202247662153858,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.6117005377321011, "loss": 0.648199727435189,
"n_samples": 2907 "n_samples": 2907
}, },
{ {
"loss": 0.618167482426702, "loss": 0.6473217075085124,
"n_samples": 2907
},
{
"loss": 0.6103124622220963,
"n_samples": 2907
},
{
"loss": 0.6203263100988888,
"n_samples": 2907
},
{
"loss": 0.617202088939065,
"n_samples": 2907
},
{
"loss": 0.6263537556599373,
"n_samples": 2907
},
{
"loss": 0.6461186489679168,
"n_samples": 2907
},
{
"loss": 0.6289772454163526,
"n_samples": 2907
},
{
"loss": 0.6347806630652919,
"n_samples": 2907
},
{
"loss": 0.6358688624302136,
"n_samples": 2907
},
{
"loss": 0.646931814478975,
"n_samples": 2907
},
{
"loss": 0.6245040218978556,
"n_samples": 2907
},
{
"loss": 0.6267098000482632,
"n_samples": 2907 "n_samples": 2907
} }
] ]

Binary file not shown.

Before

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 277 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 280 KiB