Fix makefile

This commit is contained in:
RYDE-WORK 2026-01-19 18:06:04 +08:00
parent cbfbd1a7af
commit 88f7b51b07
2 changed files with 155 additions and 1 deletions

View File

@ -84,7 +84,12 @@ FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)
## Pretrain on external data (delivery only)
.PHONY: pretrain
pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain $(MPNN_FLAG)
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG)
## Evaluate pretrain model (delivery metrics)
.PHONY: test_pretrain
test_pretrain: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
## Train model (multi-task, from scratch)
.PHONY: train

View File

@ -356,6 +356,155 @@ def main(
)
@app.command()
def test(
# 数据路径
val_path: Path = PROCESSED_DATA_DIR / "val_pretrain.parquet",
model_path: Path = MODELS_DIR / "pretrain_delivery.pt",
output_path: Path = MODELS_DIR / "pretrain_test_results.json",
# MPNN 参数
use_mpnn: bool = False,
mpnn_device: str = "cpu",
# 其他参数
batch_size: int = 64,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
评估 pretrain 模型在外部数据上的 delivery 预测性能
输出详细指标MSE, RMSE, MAE,
"""
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
logger.info(f"Using device: {device}")
device_obj = torch.device(device)
# 加载模型
logger.info(f"Loading pretrain model from {model_path}")
checkpoint = torch.load(model_path, map_location=device_obj)
config = checkpoint["config"]
# 解析 MPNN 配置
enable_mpnn = config.get("use_mpnn", False)
if enable_mpnn or use_mpnn:
logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
ensemble_paths = find_mpnn_ensemble_paths()
logger.info(f"Found {len(ensemble_paths)} MPNN models")
model = LNPModel(
d_model=config["d_model"],
num_heads=config["num_heads"],
n_attn_layers=config["n_attn_layers"],
fusion_strategy=config["fusion_strategy"],
head_hidden_dim=config["head_hidden_dim"],
dropout=config["dropout"],
mpnn_ensemble_paths=ensemble_paths,
mpnn_device=mpnn_device,
)
else:
model = LNPModelWithoutMPNN(
d_model=config["d_model"],
num_heads=config["num_heads"],
n_attn_layers=config["n_attn_layers"],
fusion_strategy=config["fusion_strategy"],
head_hidden_dim=config["head_hidden_dim"],
dropout=config["dropout"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device_obj)
model.eval()
logger.info(f"Model config: {config}")
logger.info(f"Best val_loss from training: {checkpoint.get('best_val_loss', 'N/A')}")
# 加载数据
logger.info(f"Loading validation data from {val_path}")
val_df = pd.read_parquet(val_path)
val_dataset = ExternalDeliveryDataset(val_df)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
logger.info(f"Validation samples: {len(val_dataset)}")
# 预测
logger.info("Running predictions...")
all_preds = []
all_targets = []
all_masks = []
with torch.no_grad():
for batch in tqdm(val_loader, desc="Predicting"):
smiles = batch["smiles"]
tabular = {k: v.to(device_obj) for k, v in batch["tabular"].items()}
targets = batch["targets"]["delivery"].numpy()
mask = batch["mask"]["delivery"].numpy()
pred = model.forward_delivery(smiles, tabular).squeeze(-1).cpu().numpy()
all_preds.extend(pred)
all_targets.extend(targets)
all_masks.extend(mask)
# 转为数组
all_preds = np.array(all_preds)
all_targets = np.array(all_targets)
all_masks = np.array(all_masks, dtype=bool)
# 只计算有效样本
y_pred = all_preds[all_masks]
y_true = all_targets[all_masks]
# 计算指标
mse = float(mean_squared_error(y_true, y_pred))
rmse = float(np.sqrt(mse))
mae = float(mean_absolute_error(y_true, y_pred))
r2 = float(r2_score(y_true, y_pred))
# 额外统计
correlation = float(np.corrcoef(y_true, y_pred)[0, 1])
results = {
"model_path": str(model_path),
"val_path": str(val_path),
"n_samples": int(all_masks.sum()),
"metrics": {
"mse": mse,
"rmse": rmse,
"mae": mae,
"r2": r2,
"correlation": correlation,
},
"statistics": {
"y_true_mean": float(y_true.mean()),
"y_true_std": float(y_true.std()),
"y_pred_mean": float(y_pred.mean()),
"y_pred_std": float(y_pred.std()),
}
}
# 打印结果
logger.info("\n" + "=" * 50)
logger.info("PRETRAIN MODEL EVALUATION (Delivery)")
logger.info("=" * 50)
logger.info(f"Samples: {results['n_samples']}")
logger.info("\n[Metrics]")
logger.info(f" MSE: {mse:.4f}")
logger.info(f" RMSE: {rmse:.4f}")
logger.info(f" MAE: {mae:.4f}")
logger.info(f" R²: {r2:.4f}")
logger.info(f" Correlation: {correlation:.4f}")
logger.info("\n[Statistics]")
logger.info(f" True: mean={y_true.mean():.4f}, std={y_true.std():.4f}")
logger.info(f" Pred: mean={y_pred.mean():.4f}, std={y_pred.std():.4f}")
# 保存结果
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
logger.success(f"\nSaved results to {output_path}")
if __name__ == "__main__":
app()