Add loss visualization
@ -14,6 +14,7 @@ import typer
|
|||||||
|
|
||||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
|
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
|
||||||
|
from lnp_ml.modeling.visualization import plot_loss_curves
|
||||||
|
|
||||||
# MPNN ensemble 默认路径
|
# MPNN ensemble 默认路径
|
||||||
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||||
@ -351,6 +352,15 @@ def main(
|
|||||||
json.dump(result["history"], f, indent=2)
|
json.dump(result["history"], f, indent=2)
|
||||||
logger.success(f"Saved pretrain history to {history_path}")
|
logger.success(f"Saved pretrain history to {history_path}")
|
||||||
|
|
||||||
|
# 绘制 loss 曲线图
|
||||||
|
loss_plot_path = output_dir / "pretrain_loss_curves.png"
|
||||||
|
plot_loss_curves(
|
||||||
|
history=result["history"],
|
||||||
|
output_path=loss_plot_path,
|
||||||
|
title="Pretrain Loss Curves (Delivery)",
|
||||||
|
)
|
||||||
|
logger.success(f"Saved loss curves plot to {loss_plot_path}")
|
||||||
|
|
||||||
logger.success(
|
logger.success(
|
||||||
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
|
f"Pretraining complete! Best val_loss: {result['best_val_loss']:.4f}"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import typer
|
|||||||
|
|
||||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
|
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
|
||||||
|
from lnp_ml.modeling.visualization import plot_loss_curves
|
||||||
|
|
||||||
|
|
||||||
# MPNN ensemble 默认路径
|
# MPNN ensemble 默认路径
|
||||||
@ -226,6 +227,15 @@ def train_fold(
|
|||||||
with open(history_path, "w") as f:
|
with open(history_path, "w") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=2)
|
||||||
|
|
||||||
|
# 绘制 loss 曲线图
|
||||||
|
loss_plot_path = fold_output_dir / "loss_curves.png"
|
||||||
|
plot_loss_curves(
|
||||||
|
history=history,
|
||||||
|
output_path=loss_plot_path,
|
||||||
|
title=f"Pretrain Fold {fold_idx} Loss Curves",
|
||||||
|
)
|
||||||
|
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"fold_idx": fold_idx,
|
"fold_idx": fold_idx,
|
||||||
"best_val_loss": best_val_loss,
|
"best_val_loss": best_val_loss,
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from lnp_ml.modeling.trainer import (
|
|||||||
EarlyStopping,
|
EarlyStopping,
|
||||||
LossWeights,
|
LossWeights,
|
||||||
)
|
)
|
||||||
|
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
|
||||||
|
|
||||||
# MPNN ensemble 默认路径
|
# MPNN ensemble 默认路径
|
||||||
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||||
@ -406,6 +407,15 @@ def main(
|
|||||||
json.dump(result["history"], f, indent=2)
|
json.dump(result["history"], f, indent=2)
|
||||||
logger.success(f"Saved training history to {history_path}")
|
logger.success(f"Saved training history to {history_path}")
|
||||||
|
|
||||||
|
# 绘制多任务 loss 曲线图
|
||||||
|
loss_plot_path = output_dir / "loss_curves.png"
|
||||||
|
plot_multitask_loss_curves(
|
||||||
|
history=result["history"],
|
||||||
|
output_path=loss_plot_path,
|
||||||
|
title="Multi-task Training Loss Curves",
|
||||||
|
)
|
||||||
|
logger.success(f"Saved loss curves plot to {loss_plot_path}")
|
||||||
|
|
||||||
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")
|
logger.success(f"Training complete! Best val_loss: {result['best_val_loss']:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from lnp_ml.modeling.trainer import (
|
|||||||
EarlyStopping,
|
EarlyStopping,
|
||||||
LossWeights,
|
LossWeights,
|
||||||
)
|
)
|
||||||
|
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
|
||||||
|
|
||||||
|
|
||||||
# MPNN ensemble 默认路径
|
# MPNN ensemble 默认路径
|
||||||
@ -158,6 +159,15 @@ def train_fold(
|
|||||||
with open(history_path, "w") as f:
|
with open(history_path, "w") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=2)
|
||||||
|
|
||||||
|
# 绘制多任务 loss 曲线图
|
||||||
|
loss_plot_path = fold_output_dir / "loss_curves.png"
|
||||||
|
plot_multitask_loss_curves(
|
||||||
|
history=history,
|
||||||
|
output_path=loss_plot_path,
|
||||||
|
title=f"Fold {fold_idx} Multi-task Loss Curves",
|
||||||
|
)
|
||||||
|
logger.info(f"Saved fold {fold_idx} loss curves to {loss_plot_path}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"fold_idx": fold_idx,
|
"fold_idx": fold_idx,
|
||||||
"best_val_loss": best_val_loss,
|
"best_val_loss": best_val_loss,
|
||||||
|
|||||||
@ -13,12 +13,18 @@ from tqdm import tqdm
|
|||||||
@dataclass
class LossWeights:
    """Loss weight applied to each task.

    Every weight previously defaulted to 1.0; the defaults for ``size``,
    ``pdi``, ``ee`` and ``toxic`` are now lower than those for ``delivery``
    and ``biodist``.
    """

    size: float = 0.1
    pdi: float = 0.3
    ee: float = 0.3
    delivery: float = 1.0
    biodist: float = 1.0
    toxic: float = 0.05
|
||||||
|
|
||||||
|
|
||||||
def compute_multitask_loss(
|
def compute_multitask_loss(
|
||||||
|
|||||||
284
lnp_ml/modeling/visualization.py
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
"""训练过程可视化工具:绘制 loss 曲线"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported so the
# choice also takes effect on matplotlib versions that bind the backend at
# pyplot import time. Agg avoids any GUI dependency.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Font fallback chain intended to provide Chinese glyph support.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
|
|
||||||
|
def plot_loss_curves(
    history: Union[Dict[str, List[Dict]], List[Dict]],
    output_path: Path,
    title: str = "Training Loss Curves",
    figsize: tuple = (12, 8),
) -> None:
    """Plot how each loss component evolves during training and save the figure.

    Two ``history`` layouts are supported:

    1. Split layout (train/val kept apart)::

           {"train": [{"loss": 0.1, ...}, ...], "val": [{"loss": 0.1, ...}, ...]}

    2. Flat layout (one record per epoch, as written for CV folds)::

           [{"epoch": 1, "train_loss": 0.1, "val_loss": 0.1, ...}, ...]

    Args:
        history: Training history in one of the layouts above.
        output_path: Destination image path; missing parent directories are
            created.
        title: Figure title.
        figsize: Figure size as ``(width, height)``.

    Raises:
        ValueError: If ``history`` is neither a dict containing ``"train"``
            nor a list.
    """
    # Validate the layout first so that an unsupported history fails fast
    # without leaving an empty output directory behind.
    if isinstance(history, dict) and "train" in history:
        plotter = _plot_standard_history
    elif isinstance(history, list):
        plotter = _plot_flat_history
    else:
        raise ValueError(f"Unsupported history format: {type(history)}")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plotter(history, output_path, title, figsize)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_standard_history(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Plot a split ``{"train": [...], "val": [...]}`` history and save it.

    The first subplot shows the total loss (record key ``"loss"``) of both
    splits. When any record carries further keys starting with ``"loss"``,
    a second subplot shows those per-task losses. A key missing from a
    record is plotted as 0.

    Args:
        history: Mapping with a ``"train"`` list and an optional ``"val"``
            list of per-epoch records.
        output_path: File the figure is written to.
        title: Prefix used for the subplot titles.
        figsize: ``(width, height)``; the height is scaled by the number of
            subplots.

    Returns:
        None. Nothing is drawn or written when the training history is empty.
    """
    train_history = history.get("train", [])
    val_history = history.get("val", [])

    if not train_history:
        return

    # Each split gets its own x-axis so that a validation history whose
    # length differs from the training one does not make ``ax.plot`` raise.
    train_epochs = list(range(1, len(train_history) + 1))
    val_epochs = list(range(1, len(val_history) + 1))
    last_epoch = max(len(train_history), len(val_history))

    # Separate the total loss from the per-task losses.
    total_loss_key = "loss"
    task_loss_keys = sorted(
        {
            key
            for record in (*train_history, *val_history)
            for key in record
            if key.startswith("loss") and key != total_loss_key
        }
    )

    n_subplots = 2 if task_loss_keys else 1
    fig, axes = plt.subplots(
        n_subplots,
        1,
        figsize=(figsize[0], figsize[1] * n_subplots / 2),
        squeeze=False,  # always a 2-D array, so no special case for one subplot
    )

    # Close the figure even if drawing or saving fails, so repeated calls
    # (e.g. one per CV fold) do not accumulate open figures.
    try:
        colors = plt.cm.tab10.colors

        # Subplot 1: total loss.
        ax = axes[0, 0]
        train_total_loss = [r.get(total_loss_key, 0) for r in train_history]
        val_total_loss = [r.get(total_loss_key, 0) for r in val_history]

        ax.plot(train_epochs, train_total_loss, 'o-', label='Train Total Loss',
                color=colors[0], linewidth=2, markersize=4)
        if val_total_loss:
            ax.plot(val_epochs, val_total_loss, 's--', label='Val Total Loss',
                    color=colors[1], linewidth=2, markersize=4)

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(f'{title} - Total Loss', fontsize=14, fontweight='bold')
        ax.legend(loc='upper right', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0.5, last_epoch + 0.5)

        # Subplot 2: per-task losses, if any.
        if task_loss_keys:
            ax = axes[1, 0]
            prefix = "loss_"
            for i, key in enumerate(task_loss_keys):
                # Strip only the leading prefix; the rest of the key is kept.
                stem = key[len(prefix):] if key.startswith(prefix) else key
                task_name = stem.upper()
                train_values = [r.get(key, 0) for r in train_history]
                val_values = [r.get(key, 0) for r in val_history]

                # Train and val curves of one task share a colour.
                color = colors[i % len(colors)]
                ax.plot(train_epochs, train_values, 'o-', label=f'Train {task_name}',
                        color=color, alpha=0.8, linewidth=1.5, markersize=3)
                # Skip validation curves that carry no positive value.
                if val_values and any(v > 0 for v in val_values):
                    ax.plot(val_epochs, val_values, 's--', label=f'Val {task_name}',
                            color=color, alpha=0.5, linewidth=1.5, markersize=3)

            ax.set_xlabel('Epoch', fontsize=12)
            ax.set_ylabel('Loss', fontsize=12)
            ax.set_title(f'{title} - Per-Task Loss', fontsize=14, fontweight='bold')
            ax.legend(loc='upper right', fontsize=9, ncol=2)
            ax.grid(True, alpha=0.3)
            ax.set_xlim(0.5, last_epoch + 0.5)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_flat_history(
    history: List[Dict],
    output_path: Path,
    title: str,
    figsize: tuple,
) -> None:
    """Plot a flat per-epoch history (the CV-fold layout) and save it.

    Every key containing ``"loss"`` (case-insensitive) is considered. Keys
    that also contain ``"train"`` are drawn as solid lines, keys that also
    contain ``"val"`` as dashed lines; loss keys containing neither are not
    drawn. A key missing from a record is plotted as 0.

    Args:
        history: One record per epoch, e.g.
            ``{"epoch": 1, "train_loss": 0.1, "val_loss": 0.1}``. A record
            without ``"epoch"`` uses its 1-based position instead.
        output_path: File the figure is written to.
        title: Axes title.
        figsize: Figure size as ``(width, height)``.

    Returns:
        None. Nothing is drawn or written when ``history`` is empty.
    """
    if not history:
        return

    epochs = [r.get("epoch", i + 1) for i, r in enumerate(history)]

    loss_keys = {
        key
        for record in history
        for key in record
        if "loss" in key.lower()
    }
    train_keys = sorted(k for k in loss_keys if "train" in k.lower())
    val_keys = sorted(k for k in loss_keys if "val" in k.lower())

    # (key, line format, alpha): training curves first, then validation,
    # so colours are assigned in that order.
    series = [(k, 'o-', 0.9) for k in train_keys]
    series += [(k, 's--', 0.7) for k in val_keys]

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Close the figure even if drawing or saving fails, so repeated calls
    # (one per fold) do not accumulate open figures.
    try:
        colors = plt.cm.tab10.colors

        for color_idx, (key, fmt, alpha) in enumerate(series):
            values = [r.get(key, 0) for r in history]
            label = key.replace("_", " ").title()
            ax.plot(epochs, values, fmt, label=label,
                    color=colors[color_idx % len(colors)],
                    linewidth=2, markersize=4, alpha=alpha)

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        # A legend without labelled artists only produces a warning.
        if series:
            ax.legend(loc='upper right', fontsize=10)
        ax.grid(True, alpha=0.3)
        # Derive the limits from the actual epoch values: the x-data comes
        # from the records, so 0-based or resumed numbering must not be
        # clipped. For epochs 1..N this equals (0.5, N + 0.5).
        ax.set_xlim(min(epochs) - 0.5, max(epochs) + 0.5)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_multitask_loss_curves(
    history: Dict[str, List[Dict]],
    output_path: Path,
    title: str = "Multi-task Training Loss",
    figsize: tuple = (14, 10),
) -> None:
    """Plot loss curves for multi-task training, one subplot per task.

    The first subplot shows the total loss (record key ``"loss"``); each key
    of the form ``"loss_<task>"`` gets its own subplot so tasks can be
    compared. Subplots are laid out in a grid of at most three columns.
    When no per-task key exists, the plot falls back to
    ``_plot_standard_history``.

    Args:
        history: Training history shaped like
            ``{"train": [...], "val": [...]}``.
        output_path: Destination image path; missing parent directories are
            created.
        title: Figure title.
        figsize: ``(width, height)``; the height is scaled by the number of
            grid rows.

    Returns:
        None. Nothing is drawn or written when the training history is empty.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    train_history = history.get("train", [])
    val_history = history.get("val", [])

    if not train_history:
        return

    # Each split gets its own x-axis so that a validation history whose
    # length differs from the training one does not make ``ax.plot`` raise.
    train_epochs = list(range(1, len(train_history) + 1))
    val_epochs = list(range(1, len(val_history) + 1))

    # Collect task names from both splits. Only the leading prefix is
    # stripped so that f"loss_{task}" always rebuilds the original key.
    prefix = "loss_"
    task_keys = sorted(
        {
            key[len(prefix):]
            for record in (*train_history, *val_history)
            for key in record
            if key.startswith(prefix)
        }
    )

    if not task_keys:
        # Only a total loss is available: use the simple layout.
        _plot_standard_history(history, output_path, title, figsize)
        return

    # Total loss plus one subplot per task.
    n_plots = len(task_keys) + 1
    n_cols = min(3, n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(figsize[0], figsize[1] * n_rows / 2),
        squeeze=False,  # always a 2-D array, regardless of the grid shape
    )
    axes_flat = list(axes.flat)

    # Close the figure even if drawing or saving fails, so repeated calls
    # (one per fold) do not accumulate open figures.
    try:
        colors = plt.cm.tab10.colors

        # Subplot 1: total loss.
        ax = axes_flat[0]
        train_total = [r.get("loss", 0) for r in train_history]
        val_total = [r.get("loss", 0) for r in val_history]

        ax.plot(train_epochs, train_total, 'o-', label='Train',
                color=colors[0], linewidth=2, markersize=4)
        if val_total:
            ax.plot(val_epochs, val_total, 's--', label='Val',
                    color=colors[1], linewidth=2, markersize=4)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Total Loss', fontweight='bold')
        ax.legend(loc='upper right')
        ax.grid(True, alpha=0.3)

        # One subplot per task.
        for ax, task in zip(axes_flat[1:], task_keys):
            key = f"{prefix}{task}"
            train_values = [r.get(key, 0) for r in train_history]
            val_values = [r.get(key, 0) for r in val_history]

            # Only draw series that contain at least one positive value.
            drew_any = False
            if any(v > 0 for v in train_values):
                ax.plot(train_epochs, train_values, 'o-', label='Train',
                        color=colors[0], linewidth=2, markersize=4)
                drew_any = True
            if val_values and any(v > 0 for v in val_values):
                ax.plot(val_epochs, val_values, 's--', label='Val',
                        color=colors[1], linewidth=2, markersize=4)
                drew_any = True

            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title(f'{task.upper()} Loss', fontweight='bold')
            # A legend without labelled artists only produces a warning.
            if drew_any:
                ax.legend(loc='upper right')
            ax.grid(True, alpha=0.3)

        # Hide grid cells that have no subplot assigned.
        for spare in axes_flat[n_plots:]:
            spare.set_visible(False)

        fig.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)
|
||||||
|
|
||||||
@ -11,6 +11,6 @@
|
|||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"epochs": 100,
|
"epochs": 100,
|
||||||
"patience": 15,
|
"patience": 15,
|
||||||
"init_from_pretrain": null,
|
"init_from_pretrain": "models/pretrain_delivery.pt",
|
||||||
"freeze_backbone": false
|
"freeze_backbone": false
|
||||||
}
|
}
|
||||||
@ -2,38 +2,38 @@
|
|||||||
"fold_results": [
|
"fold_results": [
|
||||||
{
|
{
|
||||||
"fold_idx": 0,
|
"fold_idx": 0,
|
||||||
"best_val_loss": 5.7676777839660645,
|
"best_val_loss": 0.9860520362854004,
|
||||||
"epochs_trained": 24,
|
"epochs_trained": 40,
|
||||||
"final_train_loss": 1.4942118644714355
|
"final_train_loss": 0.5008097920152876
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 1,
|
"fold_idx": 1,
|
||||||
"best_val_loss": 8.418675899505615,
|
"best_val_loss": 2.4599782625834146,
|
||||||
"epochs_trained": 20,
|
"epochs_trained": 38,
|
||||||
"final_train_loss": 1.4902493238449097
|
"final_train_loss": 0.564177993271086
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 2,
|
"fold_idx": 2,
|
||||||
"best_val_loss": 3.5122547830854143,
|
"best_val_loss": 0.7660132050514221,
|
||||||
"epochs_trained": 25,
|
"epochs_trained": 43,
|
||||||
"final_train_loss": 1.7609570423762004
|
"final_train_loss": 0.6722757054699792
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 3,
|
"fold_idx": 3,
|
||||||
"best_val_loss": 3.165306806564331,
|
"best_val_loss": 1.065057098865509,
|
||||||
"epochs_trained": 21,
|
"epochs_trained": 31,
|
||||||
"final_train_loss": 2.0073827385902403
|
"final_train_loss": 0.7323974437183804
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 4,
|
"fold_idx": 4,
|
||||||
"best_val_loss": 2.996154228846232,
|
"best_val_loss": 1.321769932905833,
|
||||||
"epochs_trained": 18,
|
"epochs_trained": 36,
|
||||||
"final_train_loss": 1.9732873006300493
|
"final_train_loss": 0.5991987817817264
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"summary": {
|
"summary": {
|
||||||
"val_loss_mean": 4.772013900393532,
|
"val_loss_mean": 1.3197741071383158,
|
||||||
"val_loss_std": 2.0790222989111475
|
"val_loss_std": 0.5971552245392587
|
||||||
},
|
},
|
||||||
"config": {
|
"config": {
|
||||||
"d_model": 256,
|
"d_model": 256,
|
||||||
@ -48,7 +48,7 @@
|
|||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"epochs": 100,
|
"epochs": 100,
|
||||||
"patience": 15,
|
"patience": 15,
|
||||||
"init_from_pretrain": null,
|
"init_from_pretrain": "models/pretrain_delivery.pt",
|
||||||
"freeze_backbone": false
|
"freeze_backbone": false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BIN
models/finetune_cv/fold_0/loss_curves.png
Normal file
|
After Width: | Height: | Size: 287 KiB |
BIN
models/finetune_cv/fold_1/loss_curves.png
Normal file
|
After Width: | Height: | Size: 274 KiB |
BIN
models/finetune_cv/fold_2/loss_curves.png
Normal file
|
After Width: | Height: | Size: 274 KiB |
BIN
models/finetune_cv/fold_3/loss_curves.png
Normal file
|
After Width: | Height: | Size: 270 KiB |
BIN
models/finetune_cv/fold_4/loss_curves.png
Normal file
|
After Width: | Height: | Size: 278 KiB |
@ -2,293 +2,293 @@
|
|||||||
"fold_results": [
|
"fold_results": [
|
||||||
{
|
{
|
||||||
"fold_idx": 0,
|
"fold_idx": 0,
|
||||||
"n_samples": 95,
|
"n_samples": 87,
|
||||||
"size": {
|
"size": {
|
||||||
"n": 95,
|
"n": 87,
|
||||||
"rmse": 0.5909209144067168,
|
"rmse": 0.5578980261231201,
|
||||||
"mae": 0.376253614927593,
|
"mae": 0.293820194814397,
|
||||||
"r2": 0.005712927161997228
|
"r2": -0.16350852977979646
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n": 66,
|
"n": 64,
|
||||||
"rmse": 1.3280577883458438,
|
"rmse": 1.4180242556680394,
|
||||||
"mae": 0.5195405159964028,
|
"mae": 0.489678907149937,
|
||||||
"r2": 0.03195999739694366
|
"r2": 0.006066588761596381
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n": 95,
|
"n": 87,
|
||||||
"accuracy": 0.6105263157894737,
|
"accuracy": 0.7011494252873564,
|
||||||
"precision": 0.20350877192982456,
|
"precision": 0.382983682983683,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.3981244671781756,
|
||||||
"f1": 0.25272331154684097
|
"f1": 0.38831483607603007
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n": 95,
|
"n": 87,
|
||||||
"accuracy": 0.6736842105263158,
|
"accuracy": 0.6206896551724138,
|
||||||
"precision": 0.22456140350877193,
|
"precision": 0.4700393567388641,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.5084656084656085,
|
||||||
"f1": 0.26834381551362685
|
"f1": 0.4656627032698024
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n": 66,
|
"n": 65,
|
||||||
"accuracy": 0.8939393939393939,
|
"accuracy": 0.9846153846153847,
|
||||||
"precision": 0.44696969696969696,
|
"precision": 0.9918032786885246,
|
||||||
"recall": 0.5,
|
"recall": 0.9,
|
||||||
"f1": 0.472
|
"f1": 0.9403122130394859
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n": 66,
|
"n": 65,
|
||||||
"kl_divergence": 0.851655784204727,
|
"kl_divergence": 0.20084619060054956,
|
||||||
"js_divergence": 0.21404831573756974
|
"js_divergence": 0.048292698388857934
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 1,
|
"fold_idx": 1,
|
||||||
"n_samples": 195,
|
"n_samples": 87,
|
||||||
"size": {
|
"size": {
|
||||||
"n": 193,
|
"n": 87,
|
||||||
"rmse": 0.4425801645813746,
|
"rmse": 0.40406802223294774,
|
||||||
"mae": 0.26432527161632796,
|
"mae": 0.2981993521767101,
|
||||||
"r2": -0.026225211870033682
|
"r2": 0.09717283856858072
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n": 123,
|
"n": 58,
|
||||||
"rmse": 0.7771322048436382,
|
"rmse": 0.545857341773909,
|
||||||
"mae": 0.6133777339870822,
|
"mae": 0.3866841504832023,
|
||||||
"r2": -0.128644776760948
|
"r2": 0.39045172405838346
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n": 195,
|
"n": 87,
|
||||||
"accuracy": 0.7076923076923077,
|
"accuracy": 0.8045977011494253,
|
||||||
"precision": 0.35384615384615387,
|
"precision": 0.49282452707110247,
|
||||||
"recall": 0.5,
|
"recall": 0.45567765567765567,
|
||||||
"f1": 0.4144144144144144
|
"f1": 0.4661145617667357
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n": 195,
|
"n": 87,
|
||||||
"accuracy": 0.4205128205128205,
|
"accuracy": 0.7241379310344828,
|
||||||
"precision": 0.14017094017094017,
|
"precision": 0.7064586357039188,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.6503496503496503,
|
||||||
"f1": 0.19735258724428398
|
"f1": 0.6654761904761903
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n": 123,
|
"n": 59,
|
||||||
"accuracy": 1.0,
|
"accuracy": 0.9830508474576272,
|
||||||
"precision": 1.0,
|
"precision": 0.9912280701754386,
|
||||||
"recall": 1.0,
|
"recall": 0.8333333333333333,
|
||||||
"f1": 1.0
|
"f1": 0.8955752212389381
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n": 123,
|
"n": 58,
|
||||||
"kl_divergence": 0.9336461102028436,
|
"kl_divergence": 0.185822128176519,
|
||||||
"js_divergence": 0.24870266224462317
|
"js_divergence": 0.049566546350752166
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 2,
|
"fold_idx": 2,
|
||||||
"n_samples": 51,
|
"n_samples": 87,
|
||||||
"size": {
|
"size": {
|
||||||
"n": 51,
|
"n": 86,
|
||||||
"rmse": 0.6473513298834871,
|
"rmse": 0.5861093745094258,
|
||||||
"mae": 0.5600235602434944,
|
"mae": 0.35274335949919944,
|
||||||
"r2": -9.27515642706235
|
"r2": -0.3079648452189143
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n": 44,
|
"n": 61,
|
||||||
"rmse": 0.7721077356414991,
|
"rmse": 0.5034529339588798,
|
||||||
"mae": 0.6167582499593581,
|
"mae": 0.3725305872618175,
|
||||||
"r2": -0.4822886602727561
|
"r2": 0.596618413667312
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n": 51,
|
"n": 87,
|
||||||
"accuracy": 0.8823529411764706,
|
"accuracy": 0.7701149425287356,
|
||||||
"precision": 0.29411764705882354,
|
"precision": 0.7266666666666666,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.6349206349206349,
|
||||||
"f1": 0.3125
|
"f1": 0.6497584541062802
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n": 51,
|
"n": 87,
|
||||||
"accuracy": 0.8431372549019608,
|
"accuracy": 0.5172413793103449,
|
||||||
"precision": 0.28104575163398693,
|
"precision": 0.43155828639699606,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.40000766812361016,
|
||||||
"f1": 0.3049645390070922
|
"f1": 0.3980599647266314
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n": 47,
|
"n": 61,
|
||||||
"accuracy": 0.851063829787234,
|
"accuracy": 1.0,
|
||||||
"precision": 0.425531914893617,
|
"precision": 1.0,
|
||||||
"recall": 0.5,
|
"recall": 1.0,
|
||||||
"f1": 0.4597701149425288
|
"f1": 1.0
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n": 45,
|
"n": 61,
|
||||||
"kl_divergence": 1.1049896129018548,
|
"kl_divergence": 0.2646404257700098,
|
||||||
"js_divergence": 0.25485248115851133
|
"js_divergence": 0.07024299955112
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 3,
|
"fold_idx": 3,
|
||||||
"n_samples": 66,
|
"n_samples": 86,
|
||||||
"size": {
|
"size": {
|
||||||
"n": 66,
|
"n": 86,
|
||||||
"rmse": 0.2407212117920812,
|
"rmse": 0.32742961478246685,
|
||||||
"mae": 0.19363613562150436,
|
"mae": 0.25193805472795355,
|
||||||
"r2": -0.11204941379936861
|
"r2": -0.09589933555096875
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n": 62,
|
"n": 68,
|
||||||
"rmse": 1.0041711455927012,
|
"rmse": 0.7277366648519259,
|
||||||
"mae": 0.7132550483914993,
|
"mae": 0.42998586144462664,
|
||||||
"r2": -0.63265374674746
|
"r2": 0.4053674615039361
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n": 66,
|
"n": 86,
|
||||||
"accuracy": 0.8484848484848485,
|
"accuracy": 0.7906976744186046,
|
||||||
"precision": 0.42424242424242425,
|
"precision": 0.7513157894736842,
|
||||||
"recall": 0.5,
|
"recall": 0.6356534090909091,
|
||||||
"f1": 0.4590163934426229
|
"f1": 0.6544642857142857
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n": 66,
|
"n": 86,
|
||||||
"accuracy": 0.8181818181818182,
|
"accuracy": 0.7441860465116279,
|
||||||
"precision": 0.27692307692307694,
|
"precision": 0.6962905144216474,
|
||||||
"recall": 0.32727272727272727,
|
"recall": 0.6327243018419488,
|
||||||
"f1": 0.3
|
"f1": 0.6557768628760061
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n": 62,
|
"n": 68,
|
||||||
"accuracy": 1.0,
|
"accuracy": 1.0,
|
||||||
"precision": 1.0,
|
"precision": 1.0,
|
||||||
"recall": 1.0,
|
"recall": 1.0,
|
||||||
"f1": 1.0
|
"f1": 1.0
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n": 62,
|
"n": 68,
|
||||||
"kl_divergence": 0.9677978984139058,
|
"kl_divergence": 0.3411994760293469,
|
||||||
"js_divergence": 0.2020309307244639
|
"js_divergence": 0.07812197338009717
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"fold_idx": 4,
|
"fold_idx": 4,
|
||||||
"n_samples": 27,
|
"n_samples": 87,
|
||||||
"size": {
|
"size": {
|
||||||
"n": 27,
|
"n": 86,
|
||||||
"rmse": 0.23392834445509142,
|
"rmse": 0.28795189907825647,
|
||||||
"mae": 0.19066280788845485,
|
"mae": 0.21156823080639506,
|
||||||
"r2": -0.2667651950955112
|
"r2": 0.2077479731717109
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n": 15,
|
"n": 59,
|
||||||
"rmse": 1.9603892288630869,
|
"rmse": 0.8048179107025805,
|
||||||
"mae": 1.3892907698949177,
|
"mae": 0.5188011898327682,
|
||||||
"r2": -0.29760739742916287
|
"r2": 0.24048521206149798
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n": 27,
|
"n": 87,
|
||||||
"accuracy": 0.8888888888888888,
|
"accuracy": 0.6896551724137931,
|
||||||
"precision": 0.4444444444444444,
|
"precision": 0.4101075268817204,
|
||||||
"recall": 0.5,
|
"recall": 0.425,
|
||||||
"f1": 0.47058823529411764
|
"f1": 0.4174194267871083
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n": 27,
|
"n": 87,
|
||||||
"accuracy": 0.5925925925925926,
|
"accuracy": 0.7011494252873564,
|
||||||
"precision": 0.19753086419753085,
|
"precision": 0.6600529100529101,
|
||||||
"recall": 0.3333333333333333,
|
"recall": 0.581219806763285,
|
||||||
"f1": 0.24806201550387597
|
"f1": 0.5953238953238954
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n": 15,
|
"n": 60,
|
||||||
"accuracy": 1.0,
|
"accuracy": 0.95,
|
||||||
"precision": 1.0,
|
"precision": 0.8240740740740741,
|
||||||
"recall": 1.0,
|
"recall": 0.8818181818181818,
|
||||||
"f1": 1.0
|
"f1": 0.8498748957464553
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n": 15,
|
"n": 59,
|
||||||
"kl_divergence": 0.9389607012315264,
|
"kl_divergence": 0.207699088365002,
|
||||||
"js_divergence": 0.2470218476598176
|
"js_divergence": 0.05288953180347253
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"summary_stats": {
|
"summary_stats": {
|
||||||
"size": {
|
"size": {
|
||||||
"rmse_mean": 0.43110039302375025,
|
"rmse_mean": 0.43269138734524343,
|
||||||
"rmse_std": 0.17179051271013462,
|
"rmse_std": 0.12005218734930377,
|
||||||
"r2_mean": -1.9348966641330534,
|
"r2_mean": -0.05249037976187758,
|
||||||
"r2_std": 3.6713441784129
|
"r2_std": 0.18417364026118202
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"rmse_mean": 1.1683716206573538,
|
"rmse_mean": 0.7999778213910669,
|
||||||
"rmse_std": 0.4449374578352648,
|
"rmse_std": 0.3285507063187239,
|
||||||
"r2_mean": -0.30184691676267666,
|
"r2_mean": 0.32779788001054516,
|
||||||
"r2_std": 0.23809090378746706
|
"r2_std": 0.19664259417310184
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"accuracy_mean": 0.7875890604063979,
|
"accuracy_mean": 0.751242983159583,
|
||||||
"accuracy_std": 0.11016791908756088,
|
"accuracy_std": 0.04703609766404652,
|
||||||
"f1_mean": 0.3818484709395992,
|
"f1_mean": 0.515214312890088,
|
||||||
"f1_std": 0.08529090446864619
|
"f1_std": 0.11451705845950987
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"accuracy_mean": 0.6696217393431015,
|
"accuracy_mean": 0.6614808874632452,
|
||||||
"accuracy_std": 0.15503740047242787,
|
"accuracy_std": 0.08343692429542732,
|
||||||
"f1_mean": 0.2637445914537758,
|
"f1_mean": 0.5560599233345052,
|
||||||
"f1_std": 0.039213602228007696
|
"f1_std": 0.10638861908930271
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"accuracy_mean": 0.9490006447453256,
|
"accuracy_mean": 0.9835332464146024,
|
||||||
"accuracy_std": 0.06391582554207781,
|
"accuracy_std": 0.018265761928954075,
|
||||||
"f1_mean": 0.7863540229885058,
|
"f1_mean": 0.9371524660049758,
|
||||||
"f1_std": 0.26169039387919035
|
"f1_std": 0.05874632007606866
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"kl_mean": 0.9594100213909715,
|
"kl_mean": 0.24004146178828548,
|
||||||
"kl_std": 0.08240959093662605,
|
"kl_std": 0.05720155140134621,
|
||||||
"js_mean": 0.23333124750499712,
|
"js_mean": 0.05982274989485996,
|
||||||
"js_std": 0.021158533549255752
|
"js_std": 0.012080103435264247
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"overall": {
|
"overall": {
|
||||||
"size": {
|
"size": {
|
||||||
"n_samples": 432,
|
"n_samples": 432,
|
||||||
"mse": 0.22604480336185886,
|
"mse": 0.20179930058631468,
|
||||||
"rmse": 0.47544169291497657,
|
"rmse": 0.44922077043065883,
|
||||||
"mae": 0.3084443360567093,
|
"mae": 0.2817203010673876,
|
||||||
"r2": -0.2313078534105617
|
"r2": -0.09923811531698323
|
||||||
},
|
},
|
||||||
"delivery": {
|
"delivery": {
|
||||||
"n_samples": 310,
|
"n_samples": 310,
|
||||||
"mse": 1.0873755440675295,
|
"mse": 0.760202623547859,
|
||||||
"rmse": 1.0427730069710903,
|
"rmse": 0.8718959935381393,
|
||||||
"mae": 0.6513989447841361,
|
"mae": 0.4398058238289049,
|
||||||
"r2": -0.09443640807387799
|
"r2": 0.23486100707044422
|
||||||
},
|
},
|
||||||
"pdi": {
|
"pdi": {
|
||||||
"n_samples": 434,
|
"n_samples": 434,
|
||||||
"accuracy": 0.7396313364055299,
|
"accuracy": 0.7511520737327189,
|
||||||
"precision": 0.18490783410138248,
|
"precision": 0.3287852263755878,
|
||||||
"recall": 0.25,
|
"recall": 0.3184060228452752,
|
||||||
"f1": 0.21258278145695364
|
"f1": 0.3212571677885814
|
||||||
},
|
},
|
||||||
"ee": {
|
"ee": {
|
||||||
"n_samples": 434,
|
"n_samples": 434,
|
||||||
"accuracy": 0.5967741935483871,
|
"accuracy": 0.6612903225806451,
|
||||||
"precision": 0.1993841416474211,
|
"precision": 0.5836247086247086,
|
||||||
"recall": 0.33205128205128204,
|
"recall": 0.5469842657342657,
|
||||||
"f1": 0.24915824915824913
|
"f1": 0.5597637622559741
|
||||||
},
|
},
|
||||||
"toxic": {
|
"toxic": {
|
||||||
"n_samples": 313,
|
"n_samples": 313,
|
||||||
"accuracy": 0.9552715654952076,
|
"accuracy": 0.9840255591054313,
|
||||||
"precision": 0.4776357827476038,
|
"precision": 0.918076923076923,
|
||||||
"recall": 0.5,
|
"recall": 0.8895126612517916,
|
||||||
"f1": 0.48856209150326796
|
"f1": 0.9032337847028998
|
||||||
},
|
},
|
||||||
"biodist": {
|
"biodist": {
|
||||||
"n_samples": 311,
|
"n_samples": 311,
|
||||||
"kl_divergence": 0.9481034280166569,
|
"kl_divergence": 0.24254521665201,
|
||||||
"js_divergence": 0.23285280825310384
|
"js_divergence": 0.06022985409160514
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1,205 +1,293 @@
|
|||||||
{
|
{
|
||||||
"train": [
|
"train": [
|
||||||
{
|
{
|
||||||
"loss": 0.7730368412685099,
|
"loss": 0.7867247516051271,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.658895703010919,
|
"loss": 0.6515523084370274,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6059015260392299,
|
"loss": 0.5990842743185651,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5744731174349416,
|
"loss": 0.5633418128920326,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5452056020458733,
|
"loss": 0.5453761521296815,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5138543470936083,
|
"loss": 0.49953126002250825,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4885380559178135,
|
"loss": 0.49147369265204843,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.47587182296687974,
|
"loss": 0.4659397622863399,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4671051038255316,
|
"loss": 0.4653009635305819,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.46794115915756107,
|
"loss": 0.4380375076610923,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4293930456997915,
|
"loss": 0.4258159104875806,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.42624105651716415,
|
"loss": 0.4144523660948226,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4131358770446828,
|
"loss": 0.4008358244841981,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3946074267790835,
|
"loss": 0.40240038808127093,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3898155013755344,
|
"loss": 0.38176763174141226,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.37861797005733383,
|
"loss": 0.37277743237904,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3775682858392304,
|
"loss": 0.3573742728176747,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3800349080262064,
|
"loss": 0.3491767022517619,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.36302345173031675,
|
"loss": 0.3499675623860557,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3429561740842766,
|
"loss": 0.35079578643841286,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3445638883004898,
|
"loss": 0.3414057292594471,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.318970229203733,
|
"loss": 0.300529963052257,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.30179278279904437,
|
"loss": 0.2961940902990875,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.2887343142006437,
|
"loss": 0.2907693515383844,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.29240367556855545,
|
"loss": 0.2809350616734551,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.28143580470326973,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.2664423378391215,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.2745858784487654,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.26682337215652197,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.2681302405486289,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.26258669999889017,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.2608744821883436,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.239722755447208,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.24175641130912484,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.23785491213674798,
|
||||||
|
"n_samples": 6783
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.23117999019839675,
|
||||||
"n_samples": 6783
|
"n_samples": 6783
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"val": [
|
"val": [
|
||||||
{
|
{
|
||||||
"loss": 0.7350345371841441,
|
"loss": 0.7379055186813953,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.7165568811318536,
|
"loss": 0.7269540305477178,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.7251406249862214,
|
"loss": 0.6927794775152518,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6836505264587159,
|
"loss": 0.6652627856533758,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6747132955771933,
|
"loss": 0.6721626692594103,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6691136244936912,
|
"loss": 0.6660812889787394,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6337480902323249,
|
"loss": 0.6329412659009298,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6600317959527934,
|
"loss": 0.6395346636332554,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6439923948855346,
|
"loss": 0.6228037932749914,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.643800035575267,
|
"loss": 0.627341245329581,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6181512585221839,
|
"loss": 0.6399482499704272,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6442458634939151,
|
"loss": 0.6136556721283145,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6344759362359862,
|
"loss": 0.6253821217484764,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6501405371457472,
|
"loss": 0.6535878175511408,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6098835162990152,
|
"loss": 0.6123772015635566,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6366627322138894,
|
"loss": 0.648116178003258,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6171610150646417,
|
"loss": 0.6141229696318092,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6358801012273748,
|
"loss": 0.6307853255273552,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6239976831059871,
|
"loss": 0.6422428293329848,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6683828232827201,
|
"loss": 0.6552245357949421,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6655785786478143,
|
"loss": 0.6614342853503823,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6152775046503088,
|
"loss": 0.6153881246378465,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6202247662153858,
|
"loss": 0.6298057977526632,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.648199727435189,
|
"loss": 0.6117005377321011,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6473217075085124,
|
"loss": 0.618167482426702,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6103124622220963,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6203263100988888,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.617202088939065,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6263537556599373,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6461186489679168,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6289772454163526,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6347806630652919,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6358688624302136,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.646931814478975,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6245040218978556,
|
||||||
|
"n_samples": 2907
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.6267098000482632,
|
||||||
"n_samples": 2907
|
"n_samples": 2907
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
BIN
reports/figures/before_loss_weighting/loss_curves 0.png
Normal file
|
After Width: | Height: | Size: 278 KiB |
BIN
reports/figures/before_loss_weighting/loss_curves 1.png
Normal file
|
After Width: | Height: | Size: 274 KiB |
BIN
reports/figures/before_loss_weighting/loss_curves 2.png
Normal file
|
After Width: | Height: | Size: 274 KiB |
BIN
reports/figures/before_loss_weighting/loss_curves 3.png
Normal file
|
After Width: | Height: | Size: 277 KiB |
BIN
reports/figures/before_loss_weighting/loss_curves 4.png
Normal file
|
After Width: | Height: | Size: 280 KiB |