Compare commits

..

No commits in common. "93a6f8654d5d70984d9ff0da66fe3f5b4a7f9975" and "871afc5988504a2750d67f6ebf5b601a14efbc7c" have entirely different histories.

31 changed files with 2068 additions and 4166 deletions

View File

@ -14,7 +14,6 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -352,15 +351,6 @@ def main(
json.dump(result["history"], f, indent=2)
logger.success(f"Saved pretrain history to {history_path}")
# 绘制 loss 曲线图
loss_plot_path = output_dir / "pretrain_loss_curves.png"
plot_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Pretrain Loss Curves (Delivery)",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success(
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
)

View File

@ -16,7 +16,6 @@ import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
from lnp_ml.modeling.visualization import plot_loss_curves
# MPNN ensemble 默认路径
@ -227,15 +226,6 @@ def train_fold(
with open(history_path, "w") as f:
json.dump(history, f, indent=2)
# 绘制 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Pretrain Fold {fold_idx} Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return {
"fold_idx": fold_idx,
"best_val_loss": best_val_loss,

View File

@ -19,7 +19,6 @@ from lnp_ml.modeling.trainer import (
EarlyStopping,
LossWeights,
)
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
@ -407,15 +406,6 @@ def main(
json.dump(result["history"], f, indent=2)
logger.success(f"Saved training history to {history_path}")
# 绘制多任务 loss 曲线图
loss_plot_path = output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=result["history"],
output_path=loss_plot_path,
title="Multi-task Training Loss Curves",
)
logger.success(f"Saved loss curves plot to {loss_plot_path}")
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")

View File

@ -22,7 +22,6 @@ from lnp_ml.modeling.trainer import (
EarlyStopping,
LossWeights,
)
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# MPNN ensemble 默认路径
@ -159,15 +158,6 @@ def train_fold(
with open(history_path, "w") as f:
json.dump(history, f, indent=2)
# 绘制多任务 loss 曲线图
loss_plot_path = fold_output_dir / "loss_curves.png"
plot_multitask_loss_curves(
history=history,
output_path=loss_plot_path,
title=f"Fold {fold_idx} Multi-task Loss Curves",
)
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
return {
"fold_idx": fold_idx,
"best_val_loss": best_val_loss,

View File

@ -19,12 +19,6 @@ class LossWeights:
delivery: float = 1.0
biodist: float = 1.0
toxic: float = 1.0
# size: float = 0.1
# pdi: float = 0.3
# ee: float = 0.3
# delivery: float = 1.0
# biodist: float = 1.0
# toxic: float = 0.05
def compute_multitask_loss(

View File

@ -1,284 +0,0 @@
"""训练过程可视化工具:绘制 loss 曲线"""
from pathlib import Path
from typing import Dict, List, Optional, Union
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端,避免 GUI 依赖
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def plot_loss_curves(
    history: Union[Dict[str, List[Dict]], List[Dict]],
    output_path: Path,
    title: str = "Training Loss Curves",
    figsize: tuple = (12, 8),
) -> None:
    """Plot the loss-component trajectories recorded during training.

    Two history layouts are accepted:

    1. Pretraining (single-task) layout:
       {"train": [{"loss": 0.1, ...}, ...], "val": [{"loss": 0.1, ...}, ...]}
    2. CV-fold layout:
       [{"epoch": 1, "train_loss": 0.1, "val_loss": 0.1, ...}, ...]

    Args:
        history: Recorded training history in one of the layouts above.
        output_path: Destination image file path.
        title: Figure title.
        figsize: Figure size in inches.

    Raises:
        ValueError: If ``history`` matches neither supported layout.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Dispatch on the detected history layout, guard-clause style.
    if isinstance(history, list):
        # Layout 2: flat list of per-epoch records.
        _plot_flat_history(history, output_path, title, figsize)
        return
    if isinstance(history, dict) and "train" in history:
        # Layout 1: {"train": [...], "val": [...]}.
        _plot_standard_history(history, output_path, title, figsize)
        return
    raise ValueError(f"Unsupported history format: {type(history)}")
def _plot_standard_history(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Render a history in the {"train": [...], "val": [...]} layout.

    The top panel shows the total loss; when per-task keys ("loss_*") are
    present, a second panel overlays every per-task series.
    """
    train_records = history.get("train", [])
    val_records = history.get("val", [])
    if not train_records:
        return  # nothing to draw

    epochs = list(range(1, len(train_records) + 1))

    # Gather every loss-like key seen in either split. A key equal to
    # "loss" also starts with "loss", so one predicate suffices.
    loss_keys = {
        key
        for record in train_records + val_records
        for key in record
        if key.startswith("loss")
    }
    per_task_keys = sorted(k for k in loss_keys if k != "loss")

    # One panel for the total loss, plus one for per-task losses if any.
    panel_count = 2 if per_task_keys else 1
    fig, axes = plt.subplots(
        panel_count, 1, figsize=(figsize[0], figsize[1] * panel_count / 2)
    )
    if panel_count == 1:
        axes = [axes]
    palette = plt.cm.tab10.colors

    # Panel 1: total loss.
    ax = axes[0]
    train_total = [record.get("loss", 0) for record in train_records]
    val_total = [record.get("loss", 0) for record in val_records]
    ax.plot(epochs, train_total, 'o-', label='Train Total Loss', color=palette[0], linewidth=2, markersize=4)
    if val_total:
        ax.plot(epochs, val_total, 's--', label='Val Total Loss', color=palette[1], linewidth=2, markersize=4)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(f'{title} - Total Loss', fontsize=14, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.5, len(epochs) + 0.5)

    # Panel 2: per-task losses (only when present).
    if per_task_keys:
        ax = axes[1]
        for pos, key in enumerate(per_task_keys):
            task_label = key.replace("loss_", "").upper()
            train_series = [record.get(key, 0) for record in train_records]
            val_series = [record.get(key, 0) for record in val_records]
            series_color = palette[pos % len(palette)]
            ax.plot(epochs, train_series, 'o-', label=f'Train {task_label}', color=series_color, alpha=0.8, linewidth=1.5, markersize=3)
            # Skip all-zero validation series (task absent from this run).
            if val_series and any(v > 0 for v in val_series):
                ax.plot(epochs, val_series, 's--', label=f'Val {task_label}', color=series_color, alpha=0.5, linewidth=1.5, markersize=3)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(f'{title} - Per-Task Loss', fontsize=14, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, ncol=2)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0.5, len(epochs) + 0.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
def _plot_flat_history(
    history: List[Dict],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Render a flat per-epoch history (CV-fold layout) as a single panel."""
    if not history:
        return  # nothing to draw

    # Prefer the recorded epoch number; fall back to the 1-based position.
    epochs = [record.get("epoch", pos + 1) for pos, record in enumerate(history)]

    # Every key mentioning "loss", split into train- and val-series.
    loss_keys = {key for record in history for key in record if "loss" in key.lower()}
    train_series_keys = sorted(k for k in loss_keys if "train" in k.lower())
    val_series_keys = sorted(k for k in loss_keys if "val" in k.lower())

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    palette = plt.cm.tab10.colors

    # Draw train series first (solid, circles), then val (dashed, squares),
    # advancing through the palette continuously across both groups.
    next_color = 0
    for marker, alpha, keys in (('o-', 0.9, train_series_keys), ('s--', 0.7, val_series_keys)):
        for key in keys:
            series = [record.get(key, 0) for record in history]
            ax.plot(
                epochs, series, marker,
                label=key.replace("_", " ").title(),
                color=palette[next_color % len(palette)],
                linewidth=2, markersize=4, alpha=alpha,
            )
            next_color += 1

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.5, len(epochs) + 0.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
def plot_multitask_loss_curves(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str = "Multi-task Training Loss",
    figsize: tuple = (14, 10),
) -> None:
    """Plot loss curves for multi-task training, one subplot per task.

    Each task's loss is drawn in its own panel (plus one for the total
    loss) so the tasks can be compared side by side.

    Args:
        history: Training history in the {"train": [...], "val": [...]}
            layout, with per-task keys named "loss_<task>".
        output_path: Destination image file path.
        title: Figure super-title.
        figsize: Figure size in inches (height is scaled by row count).
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    train_history = history.get("train", [])
    val_history = history.get("val", [])
    if not train_history:
        return  # nothing to draw

    epochs = list(range(1, len(train_history) + 1))

    # Collect task names from "loss_<task>" keys in the train records.
    task_keys = sorted(
        {
            key[len("loss_"):]
            for record in train_history
            for key in record
            if key.startswith("loss_")
        }
    )

    if not task_keys:
        # Only a total loss is present; fall back to the simple plot.
        _plot_standard_history(history, output_path, title, figsize)
        return

    # One subplot per task plus one for the total loss, in a <=3-wide grid.
    n_plots = len(task_keys) + 1
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols
    # squeeze=False guarantees a 2-D array of Axes for every grid shape,
    # replacing the previous fragile list-wrapping normalization (whose
    # `n_plots == 1` branch was unreachable: n_plots >= 2 here).
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(figsize[0], figsize[1] * n_rows / 2),
        squeeze=False,
    )
    axes_flat = axes.flatten()
    colors = plt.cm.tab10.colors

    # Subplot 1: total loss.
    ax = axes_flat[0]
    train_total = [r.get("loss", 0) for r in train_history]
    val_total = [r.get("loss", 0) for r in val_history]
    ax.plot(epochs, train_total, 'o-', label='Train', color=colors[0], linewidth=2, markersize=4)
    if val_total:
        ax.plot(epochs, val_total, 's--', label='Val', color=colors[1], linewidth=2, markersize=4)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Total Loss', fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    # One subplot per task.
    for idx, task in enumerate(task_keys):
        ax = axes_flat[idx + 1]
        key = f"loss_{task}"
        train_values = [r.get(key, 0) for r in train_history]
        val_values = [r.get(key, 0) for r in val_history]
        # Only draw series that carry actual (non-zero) values.
        if any(v > 0 for v in train_values):
            ax.plot(epochs, train_values, 'o-', label='Train', color=colors[0], linewidth=2, markersize=4)
        if val_values and any(v > 0 for v in val_values):
            ax.plot(epochs, val_values, 's--', label='Val', color=colors[1], linewidth=2, markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{task.upper()} Loss', fontweight='bold')
        ax.legend(loc='upper right')
        ax.grid(True, alpha=0.3)

    # Hide unused grid cells.
    for idx in range(n_plots, len(axes_flat)):
        axes_flat[idx].set_visible(False)

    plt.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

View File

@ -11,6 +11,6 @@
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt",
"init_from_pretrain": null,
"freeze_backbone": false
}

View File

@ -2,38 +2,38 @@
"fold_results": [
{
"fold_idx": 0,
"best_val_loss": 0.9860520362854004,
"epochs_trained": 40,
"final_train_loss": 0.5008097920152876
"best_val_loss": 5.7676777839660645,
"epochs_trained": 24,
"final_train_loss": 1.4942118644714355
},
{
"fold_idx": 1,
"best_val_loss": 2.4599782625834146,
"epochs_trained": 38,
"final_train_loss": 0.564177993271086
"best_val_loss": 8.418675899505615,
"epochs_trained": 20,
"final_train_loss": 1.4902493238449097
},
{
"fold_idx": 2,
"best_val_loss": 0.7660132050514221,
"epochs_trained": 43,
"final_train_loss": 0.6722757054699792
"best_val_loss": 3.5122547830854143,
"epochs_trained": 25,
"final_train_loss": 1.7609570423762004
},
{
"fold_idx": 3,
"best_val_loss": 1.065057098865509,
"epochs_trained": 31,
"final_train_loss": 0.7323974437183804
"best_val_loss": 3.165306806564331,
"epochs_trained": 21,
"final_train_loss": 2.0073827385902403
},
{
"fold_idx": 4,
"best_val_loss": 1.321769932905833,
"epochs_trained": 36,
"final_train_loss": 0.5991987817817264
"best_val_loss": 2.996154228846232,
"epochs_trained": 18,
"final_train_loss": 1.9732873006300493
}
],
"summary": {
"val_loss_mean": 1.3197741071383158,
"val_loss_std": 0.5971552245392587
"val_loss_mean": 4.772013900393532,
"val_loss_std": 2.0790222989111475
},
"config": {
"d_model": 256,
@ -48,7 +48,7 @@
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt",
"init_from_pretrain": null,
"freeze_backbone": false
}
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 287 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 270 KiB

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

View File

@ -2,293 +2,293 @@
"fold_results": [
{
"fold_idx": 0,
"n_samples": 87,
"n_samples": 95,
"size": {
"n": 87,
"rmse": 0.5578980261231201,
"mae": 0.293820194814397,
"r2": -0.16350852977979646
"n": 95,
"rmse": 0.5909209144067168,
"mae": 0.376253614927593,
"r2": 0.005712927161997228
},
"delivery": {
"n": 64,
"rmse": 1.4180242556680394,
"mae": 0.489678907149937,
"r2": 0.006066588761596381
"n": 66,
"rmse": 1.3280577883458438,
"mae": 0.5195405159964028,
"r2": 0.03195999739694366
},
"pdi": {
"n": 87,
"accuracy": 0.7011494252873564,
"precision": 0.382983682983683,
"recall": 0.3981244671781756,
"f1": 0.38831483607603007
"n": 95,
"accuracy": 0.6105263157894737,
"precision": 0.20350877192982456,
"recall": 0.3333333333333333,
"f1": 0.25272331154684097
},
"ee": {
"n": 87,
"accuracy": 0.6206896551724138,
"precision": 0.4700393567388641,
"recall": 0.5084656084656085,
"f1": 0.4656627032698024
"n": 95,
"accuracy": 0.6736842105263158,
"precision": 0.22456140350877193,
"recall": 0.3333333333333333,
"f1": 0.26834381551362685
},
"toxic": {
"n": 65,
"accuracy": 0.9846153846153847,
"precision": 0.9918032786885246,
"recall": 0.9,
"f1": 0.9403122130394859
"n": 66,
"accuracy": 0.8939393939393939,
"precision": 0.44696969696969696,
"recall": 0.5,
"f1": 0.472
},
"biodist": {
"n": 65,
"kl_divergence": 0.20084619060054956,
"js_divergence": 0.048292698388857934
"n": 66,
"kl_divergence": 0.851655784204727,
"js_divergence": 0.21404831573756974
}
},
{
"fold_idx": 1,
"n_samples": 87,
"n_samples": 195,
"size": {
"n": 87,
"rmse": 0.40406802223294774,
"mae": 0.2981993521767101,
"r2": 0.09717283856858072
"n": 193,
"rmse": 0.4425801645813746,
"mae": 0.26432527161632796,
"r2": -0.026225211870033682
},
"delivery": {
"n": 58,
"rmse": 0.545857341773909,
"mae": 0.3866841504832023,
"r2": 0.39045172405838346
"n": 123,
"rmse": 0.7771322048436382,
"mae": 0.6133777339870822,
"r2": -0.128644776760948
},
"pdi": {
"n": 87,
"accuracy": 0.8045977011494253,
"precision": 0.49282452707110247,
"recall": 0.45567765567765567,
"f1": 0.4661145617667357
"n": 195,
"accuracy": 0.7076923076923077,
"precision": 0.35384615384615387,
"recall": 0.5,
"f1": 0.4144144144144144
},
"ee": {
"n": 87,
"accuracy": 0.7241379310344828,
"precision": 0.7064586357039188,
"recall": 0.6503496503496503,
"f1": 0.6654761904761903
"n": 195,
"accuracy": 0.4205128205128205,
"precision": 0.14017094017094017,
"recall": 0.3333333333333333,
"f1": 0.19735258724428398
},
"toxic": {
"n": 59,
"accuracy": 0.9830508474576272,
"precision": 0.9912280701754386,
"recall": 0.8333333333333333,
"f1": 0.8955752212389381
"n": 123,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 58,
"kl_divergence": 0.185822128176519,
"js_divergence": 0.049566546350752166
"n": 123,
"kl_divergence": 0.9336461102028436,
"js_divergence": 0.24870266224462317
}
},
{
"fold_idx": 2,
"n_samples": 87,
"n_samples": 51,
"size": {
"n": 86,
"rmse": 0.5861093745094258,
"mae": 0.35274335949919944,
"r2": -0.3079648452189143
"n": 51,
"rmse": 0.6473513298834871,
"mae": 0.5600235602434944,
"r2": -9.27515642706235
},
"delivery": {
"n": 61,
"rmse": 0.5034529339588798,
"mae": 0.3725305872618175,
"r2": 0.596618413667312
"n": 44,
"rmse": 0.7721077356414991,
"mae": 0.6167582499593581,
"r2": -0.4822886602727561
},
"pdi": {
"n": 87,
"accuracy": 0.7701149425287356,
"precision": 0.7266666666666666,
"recall": 0.6349206349206349,
"f1": 0.6497584541062802
"n": 51,
"accuracy": 0.8823529411764706,
"precision": 0.29411764705882354,
"recall": 0.3333333333333333,
"f1": 0.3125
},
"ee": {
"n": 87,
"accuracy": 0.5172413793103449,
"precision": 0.43155828639699606,
"recall": 0.40000766812361016,
"f1": 0.3980599647266314
"n": 51,
"accuracy": 0.8431372549019608,
"precision": 0.28104575163398693,
"recall": 0.3333333333333333,
"f1": 0.3049645390070922
},
"toxic": {
"n": 61,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
"n": 47,
"accuracy": 0.851063829787234,
"precision": 0.425531914893617,
"recall": 0.5,
"f1": 0.4597701149425288
},
"biodist": {
"n": 61,
"kl_divergence": 0.2646404257700098,
"js_divergence": 0.07024299955112
"n": 45,
"kl_divergence": 1.1049896129018548,
"js_divergence": 0.25485248115851133
}
},
{
"fold_idx": 3,
"n_samples": 86,
"n_samples": 66,
"size": {
"n": 86,
"rmse": 0.32742961478246685,
"mae": 0.25193805472795355,
"r2": -0.09589933555096875
"n": 66,
"rmse": 0.2407212117920812,
"mae": 0.19363613562150436,
"r2": -0.11204941379936861
},
"delivery": {
"n": 68,
"rmse": 0.7277366648519259,
"mae": 0.42998586144462664,
"r2": 0.4053674615039361
"n": 62,
"rmse": 1.0041711455927012,
"mae": 0.7132550483914993,
"r2": -0.63265374674746
},
"pdi": {
"n": 86,
"accuracy": 0.7906976744186046,
"precision": 0.7513157894736842,
"recall": 0.6356534090909091,
"f1": 0.6544642857142857
"n": 66,
"accuracy": 0.8484848484848485,
"precision": 0.42424242424242425,
"recall": 0.5,
"f1": 0.4590163934426229
},
"ee": {
"n": 86,
"accuracy": 0.7441860465116279,
"precision": 0.6962905144216474,
"recall": 0.6327243018419488,
"f1": 0.6557768628760061
"n": 66,
"accuracy": 0.8181818181818182,
"precision": 0.27692307692307694,
"recall": 0.32727272727272727,
"f1": 0.3
},
"toxic": {
"n": 68,
"n": 62,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 68,
"kl_divergence": 0.3411994760293469,
"js_divergence": 0.07812197338009717
"n": 62,
"kl_divergence": 0.9677978984139058,
"js_divergence": 0.2020309307244639
}
},
{
"fold_idx": 4,
"n_samples": 87,
"n_samples": 27,
"size": {
"n": 86,
"rmse": 0.28795189907825647,
"mae": 0.21156823080639506,
"r2": 0.2077479731717109
"n": 27,
"rmse": 0.23392834445509142,
"mae": 0.19066280788845485,
"r2": -0.2667651950955112
},
"delivery": {
"n": 59,
"rmse": 0.8048179107025805,
"mae": 0.5188011898327682,
"r2": 0.24048521206149798
"n": 15,
"rmse": 1.9603892288630869,
"mae": 1.3892907698949177,
"r2": -0.29760739742916287
},
"pdi": {
"n": 87,
"accuracy": 0.6896551724137931,
"precision": 0.4101075268817204,
"recall": 0.425,
"f1": 0.4174194267871083
"n": 27,
"accuracy": 0.8888888888888888,
"precision": 0.4444444444444444,
"recall": 0.5,
"f1": 0.47058823529411764
},
"ee": {
"n": 87,
"accuracy": 0.7011494252873564,
"precision": 0.6600529100529101,
"recall": 0.581219806763285,
"f1": 0.5953238953238954
"n": 27,
"accuracy": 0.5925925925925926,
"precision": 0.19753086419753085,
"recall": 0.3333333333333333,
"f1": 0.24806201550387597
},
"toxic": {
"n": 60,
"accuracy": 0.95,
"precision": 0.8240740740740741,
"recall": 0.8818181818181818,
"f1": 0.8498748957464553
"n": 15,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 59,
"kl_divergence": 0.207699088365002,
"js_divergence": 0.05288953180347253
"n": 15,
"kl_divergence": 0.9389607012315264,
"js_divergence": 0.2470218476598176
}
}
],
"summary_stats": {
"size": {
"rmse_mean": 0.43269138734524343,
"rmse_std": 0.12005218734930377,
"r2_mean": -0.05249037976187758,
"r2_std": 0.18417364026118202
"rmse_mean": 0.43110039302375025,
"rmse_std": 0.17179051271013462,
"r2_mean": -1.9348966641330534,
"r2_std": 3.6713441784129
},
"delivery": {
"rmse_mean": 0.7999778213910669,
"rmse_std": 0.3285507063187239,
"r2_mean": 0.32779788001054516,
"r2_std": 0.19664259417310184
"rmse_mean": 1.1683716206573538,
"rmse_std": 0.4449374578352648,
"r2_mean": -0.30184691676267666,
"r2_std": 0.23809090378746706
},
"pdi": {
"accuracy_mean": 0.751242983159583,
"accuracy_std": 0.04703609766404652,
"f1_mean": 0.515214312890088,
"f1_std": 0.11451705845950987
"accuracy_mean": 0.7875890604063979,
"accuracy_std": 0.11016791908756088,
"f1_mean": 0.3818484709395992,
"f1_std": 0.08529090446864619
},
"ee": {
"accuracy_mean": 0.6614808874632452,
"accuracy_std": 0.08343692429542732,
"f1_mean": 0.5560599233345052,
"f1_std": 0.10638861908930271
"accuracy_mean": 0.6696217393431015,
"accuracy_std": 0.15503740047242787,
"f1_mean": 0.2637445914537758,
"f1_std": 0.039213602228007696
},
"toxic": {
"accuracy_mean": 0.9835332464146024,
"accuracy_std": 0.018265761928954075,
"f1_mean": 0.9371524660049758,
"f1_std": 0.05874632007606866
"accuracy_mean": 0.9490006447453256,
"accuracy_std": 0.06391582554207781,
"f1_mean": 0.7863540229885058,
"f1_std": 0.26169039387919035
},
"biodist": {
"kl_mean": 0.24004146178828548,
"kl_std": 0.05720155140134621,
"js_mean": 0.05982274989485996,
"js_std": 0.012080103435264247
"kl_mean": 0.9594100213909715,
"kl_std": 0.08240959093662605,
"js_mean": 0.23333124750499712,
"js_std": 0.021158533549255752
}
},
"overall": {
"size": {
"n_samples": 432,
"mse": 0.20179930058631468,
"rmse": 0.44922077043065883,
"mae": 0.2817203010673876,
"r2": -0.09923811531698323
"mse": 0.22604480336185886,
"rmse": 0.47544169291497657,
"mae": 0.3084443360567093,
"r2": -0.2313078534105617
},
"delivery": {
"n_samples": 310,
"mse": 0.760202623547859,
"rmse": 0.8718959935381393,
"mae": 0.4398058238289049,
"r2": 0.23486100707044422
"mse": 1.0873755440675295,
"rmse": 1.0427730069710903,
"mae": 0.6513989447841361,
"r2": -0.09443640807387799
},
"pdi": {
"n_samples": 434,
"accuracy": 0.7511520737327189,
"precision": 0.3287852263755878,
"recall": 0.3184060228452752,
"f1": 0.3212571677885814
"accuracy": 0.7396313364055299,
"precision": 0.18490783410138248,
"recall": 0.25,
"f1": 0.21258278145695364
},
"ee": {
"n_samples": 434,
"accuracy": 0.6612903225806451,
"precision": 0.5836247086247086,
"recall": 0.5469842657342657,
"f1": 0.5597637622559741
"accuracy": 0.5967741935483871,
"precision": 0.1993841416474211,
"recall": 0.33205128205128204,
"f1": 0.24915824915824913
},
"toxic": {
"n_samples": 313,
"accuracy": 0.9840255591054313,
"precision": 0.918076923076923,
"recall": 0.8895126612517916,
"f1": 0.9032337847028998
"accuracy": 0.9552715654952076,
"precision": 0.4776357827476038,
"recall": 0.5,
"f1": 0.48856209150326796
},
"biodist": {
"n_samples": 311,
"kl_divergence": 0.24254521665201,
"js_divergence": 0.06022985409160514
"kl_divergence": 0.9481034280166569,
"js_divergence": 0.23285280825310384
}
}
}

Binary file not shown.

View File

@ -1,293 +1,205 @@
{
"train": [
{
"loss": 0.7867247516051271,
"loss": 0.7730368412685099,
"n_samples": 6783
},
{
"loss": 0.6515523084370274,
"loss": 0.658895703010919,
"n_samples": 6783
},
{
"loss": 0.5990842743185651,
"loss": 0.6059015260392299,
"n_samples": 6783
},
{
"loss": 0.5633418128920326,
"loss": 0.5744731174349416,
"n_samples": 6783
},
{
"loss": 0.5453761521296815,
"loss": 0.5452056020458733,
"n_samples": 6783
},
{
"loss": 0.49953126002250825,
"loss": 0.5138543470936083,
"n_samples": 6783
},
{
"loss": 0.49147369265204843,
"loss": 0.4885380559178135,
"n_samples": 6783
},
{
"loss": 0.4659397622863399,
"loss": 0.47587182296687974,
"n_samples": 6783
},
{
"loss": 0.4653009635305819,
"loss": 0.4671051038255316,
"n_samples": 6783
},
{
"loss": 0.4380375076610923,
"loss": 0.46794115915756107,
"n_samples": 6783
},
{
"loss": 0.4258159104875806,
"loss": 0.4293930456997915,
"n_samples": 6783
},
{
"loss": 0.4144523660948226,
"loss": 0.42624105651716415,
"n_samples": 6783
},
{
"loss": 0.4008358244841981,
"loss": 0.4131358770446828,
"n_samples": 6783
},
{
"loss": 0.40240038808127093,
"loss": 0.3946074267790835,
"n_samples": 6783
},
{
"loss": 0.38176763174141226,
"loss": 0.3898155013755344,
"n_samples": 6783
},
{
"loss": 0.37277743237904,
"loss": 0.37861797005733383,
"n_samples": 6783
},
{
"loss": 0.3573742728176747,
"loss": 0.3775682858392304,
"n_samples": 6783
},
{
"loss": 0.3491767022517619,
"loss": 0.3800349080262064,
"n_samples": 6783
},
{
"loss": 0.3499675623860557,
"loss": 0.36302345173031675,
"n_samples": 6783
},
{
"loss": 0.35079578643841286,
"loss": 0.3429561740842766,
"n_samples": 6783
},
{
"loss": 0.3414057292594471,
"loss": 0.3445638883004898,
"n_samples": 6783
},
{
"loss": 0.300529963052257,
"loss": 0.318970229203733,
"n_samples": 6783
},
{
"loss": 0.2961940902990875,
"loss": 0.30179278279904437,
"n_samples": 6783
},
{
"loss": 0.2907693515383844,
"loss": 0.2887343142006437,
"n_samples": 6783
},
{
"loss": 0.2809350616734551,
"n_samples": 6783
},
{
"loss": 0.28143580470326973,
"n_samples": 6783
},
{
"loss": 0.2664423378391215,
"n_samples": 6783
},
{
"loss": 0.2745858784487654,
"n_samples": 6783
},
{
"loss": 0.26682337215652197,
"n_samples": 6783
},
{
"loss": 0.2681302405486289,
"n_samples": 6783
},
{
"loss": 0.26258669999889017,
"n_samples": 6783
},
{
"loss": 0.2608744821883436,
"n_samples": 6783
},
{
"loss": 0.239722755447208,
"n_samples": 6783
},
{
"loss": 0.24175641130912484,
"n_samples": 6783
},
{
"loss": 0.23785491213674798,
"n_samples": 6783
},
{
"loss": 0.23117999019839675,
"loss": 0.29240367556855545,
"n_samples": 6783
}
],
"val": [
{
"loss": 0.7379055186813953,
"loss": 0.7350345371841441,
"n_samples": 2907
},
{
"loss": 0.7269540305477178,
"loss": 0.7165568811318536,
"n_samples": 2907
},
{
"loss": 0.6927794775152518,
"loss": 0.7251406249862214,
"n_samples": 2907
},
{
"loss": 0.6652627856533758,
"loss": 0.6836505264587159,
"n_samples": 2907
},
{
"loss": 0.6721626692594103,
"loss": 0.6747132955771933,
"n_samples": 2907
},
{
"loss": 0.6660812889787394,
"loss": 0.6691136244936912,
"n_samples": 2907
},
{
"loss": 0.6329412659009298,
"loss": 0.6337480902323249,
"n_samples": 2907
},
{
"loss": 0.6395346636332554,
"loss": 0.6600317959527934,
"n_samples": 2907
},
{
"loss": 0.6228037932749914,
"loss": 0.6439923948855346,
"n_samples": 2907
},
{
"loss": 0.627341245329581,
"loss": 0.643800035575267,
"n_samples": 2907
},
{
"loss": 0.6399482499704272,
"loss": 0.6181512585221839,
"n_samples": 2907
},
{
"loss": 0.6136556721283145,
"loss": 0.6442458634939151,
"n_samples": 2907
},
{
"loss": 0.6253821217484764,
"loss": 0.6344759362359862,
"n_samples": 2907
},
{
"loss": 0.6535878175511408,
"loss": 0.6501405371457472,
"n_samples": 2907
},
{
"loss": 0.6123772015635566,
"loss": 0.6098835162990152,
"n_samples": 2907
},
{
"loss": 0.648116178003258,
"loss": 0.6366627322138894,
"n_samples": 2907
},
{
"loss": 0.6141229696318092,
"loss": 0.6171610150646417,
"n_samples": 2907
},
{
"loss": 0.6307853255273552,
"loss": 0.6358801012273748,
"n_samples": 2907
},
{
"loss": 0.6422428293329848,
"loss": 0.6239976831059871,
"n_samples": 2907
},
{
"loss": 0.6552245357949421,
"loss": 0.6683828232827201,
"n_samples": 2907
},
{
"loss": 0.6614342853503823,
"loss": 0.6655785786478143,
"n_samples": 2907
},
{
"loss": 0.6153881246378465,
"loss": 0.6152775046503088,
"n_samples": 2907
},
{
"loss": 0.6298057977526632,
"loss": 0.6202247662153858,
"n_samples": 2907
},
{
"loss": 0.6117005377321011,
"loss": 0.648199727435189,
"n_samples": 2907
},
{
"loss": 0.618167482426702,
"n_samples": 2907
},
{
"loss": 0.6103124622220963,
"n_samples": 2907
},
{
"loss": 0.6203263100988888,
"n_samples": 2907
},
{
"loss": 0.617202088939065,
"n_samples": 2907
},
{
"loss": 0.6263537556599373,
"n_samples": 2907
},
{
"loss": 0.6461186489679168,
"n_samples": 2907
},
{
"loss": 0.6289772454163526,
"n_samples": 2907
},
{
"loss": 0.6347806630652919,
"n_samples": 2907
},
{
"loss": 0.6358688624302136,
"n_samples": 2907
},
{
"loss": 0.646931814478975,
"n_samples": 2907
},
{
"loss": 0.6245040218978556,
"n_samples": 2907
},
{
"loss": 0.6267098000482632,
"loss": 0.6473217075085124,
"n_samples": 2907
}
]

Binary file not shown.

Before

Width:  |  Height:  |  Size: 278 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 274 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 277 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 280 KiB