mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 20:03:14 +08:00
Fix makefile
This commit is contained in:
parent
cbfbd1a7af
commit
88f7b51b07
7
Makefile
7
Makefile
@ -84,7 +84,12 @@ FREEZE_FLAG = $(if $(FREEZE_BACKBONE),--freeze-backbone,)
|
||||
## Pretrain on external data (delivery only)
# The pretrain module exposes more than one CLI command (`main` and `test`),
# so the sub-command must be named explicitly; the bare invocation without a
# sub-command is the pre-fix form and is intentionally not kept.
.PHONY: pretrain
pretrain: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain main $(MPNN_FLAG)

## Evaluate pretrain model (delivery metrics)
# Runs the `test` command of the same module; $(MPNN_FLAG) is forwarded so the
# evaluation uses the same MPNN setting as the pretraining run.
.PHONY: test_pretrain
test_pretrain: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
|
||||
|
||||
## Train model (multi-task, from scratch)
|
||||
.PHONY: train
|
||||
|
||||
@ -356,6 +356,155 @@ def main(
|
||||
)
|
||||
|
||||
|
||||
@app.command()
def test(
    # Data paths
    val_path: Path = PROCESSED_DATA_DIR / "val_pretrain.parquet",
    model_path: Path = MODELS_DIR / "pretrain_delivery.pt",
    output_path: Path = MODELS_DIR / "pretrain_test_results.json",
    # MPNN options
    use_mpnn: bool = False,
    mpnn_device: str = "cpu",
    # Misc options
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate the pretrained model's delivery predictions on external data.

    Reports MSE, RMSE, MAE, R² and the Pearson correlation, logs them, and
    writes them as JSON to ``output_path``.

    Args:
        val_path: Parquet file holding the validation split.
        model_path: Checkpoint containing ``config`` and ``model_state_dict``.
        output_path: Destination of the JSON report; parent directories are
            created when missing.
        use_mpnn: Force the MPNN-backed model even if the checkpoint config
            does not request it.
        mpnn_device: Device string passed to the MPNN ensemble.
        batch_size: Batch size used for inference.
        device: Torch device string the model and tabular inputs are moved to.

    Raises:
        ValueError: If the validation loader yields no batches, or no sample
            has a valid (unmasked) delivery target.
    """
    import numpy as np
    from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

    logger.info(f"Using device: {device}")
    device_obj = torch.device(device)

    # Load the checkpoint and the hyper-parameters it was trained with.
    logger.info(f"Loading pretrain model from {model_path}")
    checkpoint = torch.load(model_path, map_location=device_obj)
    config = checkpoint["config"]

    # Constructor arguments shared by both model variants.
    shared_kwargs = {
        key: config[key]
        for key in (
            "d_model",
            "num_heads",
            "n_attn_layers",
            "fusion_strategy",
            "head_hidden_dim",
            "dropout",
        )
    }

    # Use the MPNN variant when either the checkpoint or the caller asks for it.
    if config.get("use_mpnn", False) or use_mpnn:
        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
        ensemble_paths = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths)} MPNN models")
        model = LNPModel(
            **shared_kwargs,
            mpnn_ensemble_paths=ensemble_paths,
            mpnn_device=mpnn_device,
        )
    else:
        model = LNPModelWithoutMPNN(**shared_kwargs)

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device_obj)
    model.eval()

    logger.info(f"Model config: {config}")
    logger.info(f"Best val_loss from training: {checkpoint.get('best_val_loss', 'N/A')}")

    # Load the validation data.
    logger.info(f"Loading validation data from {val_path}")
    val_df = pd.read_parquet(val_path)
    val_dataset = ExternalDeliveryDataset(val_df)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )
    logger.info(f"Validation samples: {len(val_dataset)}")

    # Run inference, collecting one array per batch instead of extending
    # Python lists element by element.
    logger.info("Running predictions...")
    pred_chunks = []
    target_chunks = []
    mask_chunks = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Predicting"):
            tabular = {k: v.to(device_obj) for k, v in batch["tabular"].items()}
            pred = model.forward_delivery(batch["smiles"], tabular)
            # reshape(-1) keeps every chunk 1-D, including a 0-d result that
            # squeeze could produce for a single-sample batch.
            pred_chunks.append(pred.squeeze(-1).cpu().numpy().reshape(-1))
            target_chunks.append(batch["targets"]["delivery"].numpy().reshape(-1))
            mask_chunks.append(batch["mask"]["delivery"].numpy().reshape(-1))

    if not pred_chunks:
        raise ValueError(f"No validation batches were produced from {val_path}")

    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)
    all_masks = np.concatenate(mask_chunks).astype(bool)

    # Only samples with a valid delivery target take part in the metrics.
    n_valid = int(all_masks.sum())
    if n_valid == 0:
        raise ValueError(f"No valid delivery targets found in {val_path}")

    y_pred = all_preds[all_masks]
    y_true = all_targets[all_masks]

    # Regression metrics.
    mse = float(mean_squared_error(y_true, y_pred))
    rmse = float(np.sqrt(mse))
    mae = float(mean_absolute_error(y_true, y_pred))
    r2 = float(r2_score(y_true, y_pred))

    # Pearson correlation.
    # NOTE(review): this is NaN when either vector is constant or there is a
    # single sample, and json.dump then emits the non-standard token `NaN`.
    correlation = float(np.corrcoef(y_true, y_pred)[0, 1])

    results = {
        "model_path": str(model_path),
        "val_path": str(val_path),
        "n_samples": n_valid,
        "metrics": {
            "mse": mse,
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "correlation": correlation,
        },
        "statistics": {
            "y_true_mean": float(y_true.mean()),
            "y_true_std": float(y_true.std()),
            "y_pred_mean": float(y_pred.mean()),
            "y_pred_std": float(y_pred.std()),
        },
    }

    # Log a human-readable summary.
    logger.info("\n" + "=" * 50)
    logger.info("PRETRAIN MODEL EVALUATION (Delivery)")
    logger.info("=" * 50)
    logger.info(f"Samples: {results['n_samples']}")
    logger.info("\n[Metrics]")
    logger.info(f" MSE: {mse:.4f}")
    logger.info(f" RMSE: {rmse:.4f}")
    logger.info(f" MAE: {mae:.4f}")
    logger.info(f" R²: {r2:.4f}")
    logger.info(f" Correlation: {correlation:.4f}")
    logger.info("\n[Statistics]")
    logger.info(f" True: mean={y_true.mean():.4f}, std={y_true.std():.4f}")
    logger.info(f" Pred: mean={y_pred.mean():.4f}, std={y_pred.std():.4f}")

    # Persist the report; make sure the destination directory exists first.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.success(f"\nSaved results to {output_path}")
|
||||
|
||||
|
||||
# Script entry point: hand control to the CLI application object when this
# module is executed directly rather than imported.
if __name__ == "__main__":
    app()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user