lnp_ml/scripts/preprocess_internal.py

79 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""数据清洗脚本:修正原始数据中的问题"""
import json
from pathlib import Path
import numpy as np
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import RAW_DATA_DIR, INTERIM_DATA_DIR
# "app" folder two directory levels above this file
# (for <pkg>/scripts/preprocess_internal.py this is <pkg>/app).
APP_DIR = Path(__file__).resolve().parent.parent / "app"

# Typer CLI application; commands are registered with @app.command().
app = typer.Typer()
def _zscore_by_route(df: pd.DataFrame) -> dict:
    """Z-score ``unnormalized_delivery`` separately within each route of administration.

    Adds a ``quantified_delivery`` column to *df* (in place) and returns
    per-route statistics keyed by ``str(route)``:

    - ``mean`` / ``std``: sample mean and standard deviation (ddof=1) of the
      non-null raw values; recorded only for routes with at least two of them.
    - ``qd_min`` / ``qd_max``: range of the non-null z-scores for that route;
      recorded only when at least one such z-score exists.
    """
    grouped = df.groupby("Route_of_administration")["unnormalized_delivery"]
    route_mean = grouped.transform("mean")
    route_std = grouped.transform("std")
    # A route whose values are all identical has std == 0. Dividing would give
    # 0/0 = NaN or, through float rounding in the mean, +/-inf. Mask those so
    # the z-scores are NaN instead of leaking infinities into the CSV/JSON.
    df["quantified_delivery"] = (
        (df["unnormalized_delivery"] - route_mean) / route_std.where(route_std > 0)
    )

    stats = {}
    for route, group in df.groupby("Route_of_administration"):
        vals = group["unnormalized_delivery"].dropna()
        if len(vals) <= 1:
            continue  # sample std is undefined for fewer than two values
        mean, std = float(vals.mean()), float(vals.std())
        # str() keeps json.dump working even if route labels are not strings.
        entry = {"mean": mean, "std": std}
        stats[str(route)] = entry
        logger.info(f" {route}: mean={mean:.6f}, std={std:.6f}, n={len(vals)}")

        # Per-route z-score range, used downstream to rescale scores to [0, 1].
        qd = group["quantified_delivery"].dropna()
        if len(qd) > 0:
            entry["qd_min"] = float(qd.min())
            entry["qd_max"] = float(qd.max())
            logger.info(f" {route}: qd_min={qd.min():.6f}, qd_max={qd.max():.6f}")
    return stats


def _log_size(size: pd.Series) -> pd.Series:
    """Return the natural log of *size*; unparseable or non-positive values become NaN."""
    size = pd.to_numeric(size, errors="coerce")
    # log is undefined for values <= 0. Mask them first so np.log produces
    # neither -inf nor divide-by-zero / invalid-value RuntimeWarnings.
    return np.log(size.where(size > 0))


@app.command()
def main(
    input_path: Path = RAW_DATA_DIR / "internal.xlsx",
    output_path: Path = INTERIM_DATA_DIR / "internal.csv",
    stats_path: Path = APP_DIR / "delivery_zscore_stats.json",
):
    """
    Preprocess the internal dataset: per-route z-score of delivery, log of size.

    Steps:
    1. Z-score ``unnormalized_delivery`` within each ``Route_of_administration``
       group, storing the result in ``quantified_delivery``.
    2. Replace ``size`` with its natural log (non-positive values become NaN).
    3. Save the per-route z-score mean/std (and z-score min/max) as JSON to
       *stats_path*, so inference can invert the normalization.

    The cleaned table is written to *output_path* as CSV.
    """
    logger.info(f"Loading data from {input_path}")
    # NOTE(review): header=2 uses the sheet's third row as the column header,
    # i.e. assumes two title rows above it — confirm against the workbook.
    df = pd.read_excel(input_path, header=2)
    logger.info(f"Loaded {len(df)} samples")

    # Re-normalize the intramuscular and intravenous groups independently.
    logger.info("Z-score normalizing delivery by Route_of_administration...")
    df["unnormalized_delivery"] = pd.to_numeric(df["unnormalized_delivery"], errors="coerce")
    zscore_stats = _zscore_by_route(df)

    stats_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(zscore_stats, f, indent=2)
    logger.success(f"Saved delivery stats to {stats_path}")

    logger.info("Log-transforming size column...")
    df["size"] = _log_size(df["size"])

    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    logger.success(f"Saved cleaned data to {output_path}")


if __name__ == "__main__":
    app()