diff --git a/app/api.py b/app/api.py index 04b91c8..099eabb 100644 --- a/app/api.py +++ b/app/api.py @@ -31,17 +31,17 @@ from app.optimize import ( # ============ Pydantic Models ============ class CompRangesRequest(BaseModel): - """组分范围配置""" - weight_ratio_min: float = Field(default=0.05, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最小值") - weight_ratio_max: float = Field(default=0.30, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最大值") - cationic_mol_min: float = Field(default=0.05, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最小值") - cationic_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最大值") - phospholipid_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="磷脂 mol 比例最小值") - phospholipid_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="磷脂 mol 比例最大值") - cholesterol_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="胆固醇 mol 比例最小值") - cholesterol_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="胆固醇 mol 比例最大值") - peg_mol_min: float = Field(default=0.00, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最小值") - peg_mol_max: float = Field(default=0.05, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最大值") + """组分范围配置(mol 比例为百分数 0-100)""" + weight_ratio_min: float = Field(default=5.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最小值") + weight_ratio_max: float = Field(default=30.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最大值") + cationic_mol_min: float = Field(default=5.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最小值 (%)") + cationic_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最大值 (%)") + phospholipid_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="磷脂 mol 比例最小值 (%)") + phospholipid_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="磷脂 mol 比例最大值 (%)") + cholesterol_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最小值 (%)") + cholesterol_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最大值 (%)") + peg_mol_min: float = Field(default=0.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最小值 (%)") + peg_mol_max: float = Field(default=5.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最大值 (%)") def to_comp_ranges(self) -> CompRanges: """转换为 CompRanges 对象""" @@ -87,7 +87,8 @@ class OptimizeRequest(BaseModel): top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return") num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)") top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement") - step_sizes: Optional[List[float]] = Field(default=None, description="Step sizes for each iteration (default: [0.10, 0.02, 0.01])") + step_sizes: Optional[List[float]] = Field(default=None, description="Mol ratio step sizes for each iteration (default: [10, 2, 1])") + wr_step_sizes: Optional[List[float]] = Field(default=None, description="Weight ratio step sizes for each iteration (default: [5, 2, 1])") comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="组分范围配置(默认使用标准范围)") routes: Optional[List[str]] = Field(default=None, description="给药途径列表 (default: ['intravenous', 'intramuscular'])") scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="评分权重配置(默认仅按 biodist 排序)") @@ -290,7 +291,6 @@ async def optimize_formulation(request: OptimizeRequest): scoring_weights = request.scoring_weights.to_scoring_weights() try: - # 执行优化(层级搜索策略) results = optimize( smiles=request.smiles, organ=request.organ, @@ -300,6 +300,7 @@ async def optimize_formulation(request: OptimizeRequest): num_seeds=request.num_seeds, top_per_seed=request.top_per_seed, step_sizes=request.step_sizes, + wr_step_sizes=request.wr_step_sizes, comp_ranges=comp_ranges, routes=request.routes, scoring_weights=scoring_weights, diff --git a/app/app.py b/app/app.py index 2dddec2..dbac7ba 100644 --- a/app/app.py +++ b/app/app.py @@ -144,6 +144,7 @@ def call_optimize_api( num_seeds: int = None, top_per_seed: int = 1, step_sizes: list = None, + wr_step_sizes: list = None, comp_ranges: dict = None, routes: list = None, scoring_weights: dict = None, @@ -156,6 +157,7 @@ def call_optimize_api( "num_seeds": num_seeds, "top_per_seed": top_per_seed, "step_sizes": step_sizes, + "wr_step_sizes": wr_step_sizes, "comp_ranges": comp_ranges, "routes": routes, "scoring_weights": scoring_weights, @@ -354,48 +356,75 @@ def main(): use_custom_steps = st.checkbox( "自定义迭代步长", value=False, - help="默认步长为 [0.10, 0.02, 0.01],共3轮逐步精细化搜索。将某轮步长设为0可减少迭代轮数。", + help="默认 mol ratio 步长 [10, 2, 1](百分数),weight ratio 步长 [5, 2, 1],共3轮。将某轮步长设为0可减少迭代轮数。", ) if use_custom_steps: + st.caption("**Mol ratio 步长 (%)**") col1, col2, col3 = st.columns(3) with col1: step1 = st.number_input( - "第1轮步长", - min_value=0.01, max_value=0.20, value=0.10, - step=0.01, format="%.2f", - help="第1轮为全局粗搜索,步长必须大于0", + "第1轮 mol 步长", + min_value=1, max_value=20, value=10, + step=1, + help="第1轮为全局粗搜索", + key="mol_step1", ) with col2: step2 = st.number_input( - "第2轮步长", - min_value=0.00, max_value=0.10, value=0.02, - step=0.01, format="%.2f", + "第2轮 mol 步长", + min_value=0, max_value=10, value=2, + step=1, help="设为0则只进行1轮搜索", + key="mol_step2", ) with col3: step3 = st.number_input( - "第3轮步长", - min_value=0.00, max_value=0.05, value=0.01, - step=0.01, format="%.2f", + "第3轮 mol 步长", + min_value=0, max_value=5, value=1, + step=1, help="设为0则只进行2轮搜索", + key="mol_step3", ) - # 根据步长值构建实际的 step_sizes 列表 - # step2 为 0 → 只保留 [step1](1轮) - # step3 为 0 → 只保留 [step1, step2](2轮) - # 都不为 0 → [step1, step2, step3](3轮) - if step2 == 0.0: - step_sizes = [step1] - elif step3 == 0.0: - step_sizes = [step1, step2] - else: - step_sizes = [step1, step2, step3] + st.caption("**Weight ratio 步长**") + col1, col2, col3 = st.columns(3) + with col1: + wr_step1 = st.number_input( + "第1轮 WR 步长", + min_value=1.0, max_value=10.0, value=5.0, + step=1.0, format="%.1f", + key="wr_step1", + ) + with col2: + wr_step2 = st.number_input( + "第2轮 WR 步长", + min_value=0.0, max_value=5.0, value=2.0, + step=0.5, format="%.1f", + key="wr_step2", + ) + with col3: + wr_step3 = st.number_input( + "第3轮 WR 步长", + min_value=0.0, max_value=2.0, value=1.0, + step=0.5, format="%.1f", + key="wr_step3", + ) - # 显示实际迭代轮数提示 - st.caption(f"📌 实际迭代轮数: {len(step_sizes)} 轮,步长: {step_sizes}") + if step2 == 0: + step_sizes = [float(step1)] + wr_step_sizes_val = [wr_step1] + elif step3 == 0: + step_sizes = [float(step1), float(step2)] + wr_step_sizes_val = [wr_step1, wr_step2] + else: + step_sizes = [float(step1), float(step2), float(step3)] + wr_step_sizes_val = [wr_step1, wr_step2, wr_step3] + + st.caption(f"📌 实际迭代轮数: {len(step_sizes)} 轮,mol步长: {step_sizes},WR步长: {wr_step_sizes_val}") else: - step_sizes = None # 使用默认值 + step_sizes = None + wr_step_sizes_val = None st.markdown("**组分范围限制**") use_custom_ranges = st.checkbox( @@ -408,37 +437,37 @@ def main(): st.caption("阳离子脂质/mRNA 重量比") col1, col2 = st.columns(2) with col1: - weight_ratio_min = st.number_input("最小", min_value=0.01, max_value=0.50, value=0.05, step=0.01, format="%.2f", key="wr_min") + weight_ratio_min = st.number_input("最小", min_value=1.0, max_value=50.0, value=5.0, step=1.0, format="%.1f", key="wr_min") with col2: - weight_ratio_max = st.number_input("最大", min_value=0.01, max_value=0.50, value=0.30, step=0.01, format="%.2f", key="wr_max") + weight_ratio_max = st.number_input("最大", min_value=1.0, max_value=50.0, value=30.0, step=1.0, format="%.1f", key="wr_max") - st.caption("阳离子脂质 mol 比例") + st.caption("阳离子脂质 mol 比例 (%)") col1, col2 = st.columns(2) with col1: - cationic_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.05, step=0.05, format="%.2f", key="cat_min") + cationic_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=5.0, step=5.0, format="%.1f", key="cat_min") with col2: - cationic_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="cat_max") + cationic_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="cat_max") - st.caption("磷脂 mol 比例") + st.caption("磷脂 mol 比例 (%)") col1, col2 = st.columns(2) with col1: - phospholipid_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="phos_min") + phospholipid_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=0.0, step=5.0, format="%.1f", key="phos_min") with col2: - phospholipid_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="phos_max") + phospholipid_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="phos_max") - st.caption("胆固醇 mol 比例") + st.caption("胆固醇 mol 比例 (%)") col1, col2 = st.columns(2) with col1: - cholesterol_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="chol_min") + cholesterol_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=0.0, step=5.0, format="%.1f", key="chol_min") with col2: - cholesterol_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="chol_max") + cholesterol_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="chol_max") - st.caption("PEG 脂质 mol 比例") + st.caption("PEG 脂质 mol 比例 (%)") col1, col2 = st.columns(2) with col1: - peg_mol_min = st.number_input("最小", min_value=0.00, max_value=0.20, value=0.00, step=0.01, format="%.2f", key="peg_min") + peg_mol_min = st.number_input("最小", min_value=0.0, max_value=20.0, value=0.0, step=1.0, format="%.1f", key="peg_min") with col2: - peg_mol_max = st.number_input("最大", min_value=0.00, max_value=0.20, value=0.05, step=0.01, format="%.2f", key="peg_max") + peg_mol_max = st.number_input("最大", min_value=0.0, max_value=20.0, value=5.0, step=1.0, format="%.1f", key="peg_max") comp_ranges = { "weight_ratio_min": weight_ratio_min, @@ -453,13 +482,12 @@ def main(): "peg_mol_max": peg_mol_max, } - # 简单验证 min_sum = cationic_mol_min + phospholipid_mol_min + cholesterol_mol_min + peg_mol_min max_sum = cationic_mol_max + phospholipid_mol_max + cholesterol_mol_max + peg_mol_max - if min_sum > 1.0 or max_sum < 1.0: + if min_sum > 100.0 or max_sum < 100.0: st.warning("⚠️ 当前范围设置可能无法生成有效配方(mol 比例需加起来为 100%)") else: - comp_ranges = None # 使用默认值 + comp_ranges = None st.markdown("**评分/排序权重**") use_custom_scoring = st.checkbox( @@ -575,6 +603,7 @@ def main(): num_seeds=num_seeds, top_per_seed=top_per_seed, step_sizes=step_sizes, + wr_step_sizes=wr_step_sizes_val, comp_ranges=comp_ranges, routes=selected_routes, scoring_weights=scoring_weights, diff --git a/app/optimize.py b/app/optimize.py index 8c80937..6d2d705 100644 --- a/app/optimize.py +++ b/app/optimize.py @@ -44,22 +44,22 @@ AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", @dataclass class CompRanges: - """组分参数范围配置""" + """组分参数范围配置(mol 比例为百分数 0-100)""" # 阳离子脂质/mRNA 重量比 - weight_ratio_min: float = 0.05 - weight_ratio_max: float = 0.30 - # 阳离子脂质 mol 比例 - cationic_mol_min: float = 0.05 - cationic_mol_max: float = 0.80 - # 磷脂 mol 比例 - phospholipid_mol_min: float = 0.00 - phospholipid_mol_max: float = 0.80 - # 胆固醇 mol 比例 - cholesterol_mol_min: float = 0.00 - cholesterol_mol_max: float = 0.80 - # PEG 脂质 mol 比例 - peg_mol_min: float = 0.00 - peg_mol_max: float = 0.05 + weight_ratio_min: float = 5.0 + weight_ratio_max: float = 30.0 + # 阳离子脂质 mol 比例 (%) + cationic_mol_min: float = 5.0 + cationic_mol_max: float = 80.0 + # 磷脂 mol 比例 (%) + phospholipid_mol_min: float = 0.0 + phospholipid_mol_max: float = 80.0 + # 胆固醇 mol 比例 (%) + cholesterol_mol_min: float = 0.0 + cholesterol_mol_max: float = 80.0 + # PEG 脂质 mol 比例 (%) + peg_mol_min: float = 0.0 + peg_mol_max: float = 5.0 def to_dict(self) -> Dict: """转换为字典""" @@ -94,10 +94,10 @@ class CompRanges: min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max - if min_sum > 1.0: - return f"mol比例最小值之和({min_sum:.2f})超过100%,无法生成有效配方" - if max_sum < 1.0: - return f"mol比例最大值之和({max_sum:.2f})不足100%,无法生成有效配方" + if min_sum > 100.0: + return f"mol比例最小值之和({min_sum:.1f}%)超过100%,无法生成有效配方" + if max_sum < 100.0: + return f"mol比例最大值之和({max_sum:.1f}%)不足100%,无法生成有效配方" return None @@ -109,11 +109,14 @@ class CompRanges: # 默认组分范围 DEFAULT_COMP_RANGES = CompRanges() -# 最小 step size -MIN_STEP_SIZE = 0.01 +# PEG 最小 step size (百分数) +MIN_STEP_SIZE = 1 -# 迭代策略:每个迭代的 step_size -ITERATION_STEP_SIZES = [0.10, 0.02, 0.01] +# 迭代策略:mol ratio 的 step_size (百分数) +MOL_STEP_SIZES = [10, 2, 1] + +# 迭代策略:weight ratio 的 step_size(与 mol ratio 解耦) +WR_STEP_SIZES = [5, 2, 1] # Helper lipid 选项(不包含 DOTAP) HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"] @@ -343,51 +346,46 @@ def generate_grid_values( def generate_initial_grid( - step_size: float, + mol_step: float, + wr_step: float, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: """ - 生成初始搜索网格(满足 mol ratio 和为 1 的约束)。 + 生成初始搜索网格(满足 mol ratio 和为 100% 的约束)。 Args: - step_size: 搜索步长 + mol_step: mol ratio 搜索步长 (百分数) + wr_step: weight ratio 搜索步长 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) Returns: - List of (cationic_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol) + List of (weight_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol) """ if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES grid = [] - # Cationic_Lipid_to_mRNA_weight_ratio weight_ratios = np.arange( comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max + 0.001, - step_size + wr_step ) - # PEG: 单独处理,范围很小,始终用最小步长 peg_values = np.arange( comp_ranges.peg_mol_min, comp_ranges.peg_mol_max + 0.001, MIN_STEP_SIZE ) - # 其他三个 mol ratio 需要满足和为 1 - PEG - mol_step = step_size - for weight_ratio in weight_ratios: for peg in peg_values: - remaining = 1.0 - peg - # 生成满足约束的组合 + remaining = 100.0 - peg cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001 for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step): phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001 for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step): cholesterol_mol = remaining - cationic_mol - phospholipid_mol - # 检查约束 if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max): grid.append(( round(weight_ratio, 4), @@ -402,7 +400,8 @@ def generate_initial_grid( def generate_refined_grid( seeds: List[Formulation], - step_size: float, + mol_step: float, + wr_step: float, radius: int = 2, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: @@ -411,7 +410,8 @@ def generate_refined_grid( Args: seeds: 种子配方列表 - step_size: 步长 + mol_step: mol ratio 步长 (百分数) + wr_step: weight ratio 步长 radius: 扩展半径 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) @@ -424,35 +424,31 @@ def generate_refined_grid( grid_set = set() for seed in seeds: - # Weight ratio weight_ratios = generate_grid_values( - seed.cationic_lipid_to_mrna_ratio, step_size, + seed.cationic_lipid_to_mrna_ratio, wr_step, comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius ) - # PEG (始终用最小步长) peg_values = generate_grid_values( seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius ) - # Mol ratios cationic_mols = generate_grid_values( - seed.cationic_lipid_mol_ratio, step_size, + seed.cationic_lipid_mol_ratio, mol_step, comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius ) phospholipid_mols = generate_grid_values( - seed.phospholipid_mol_ratio, step_size, + seed.phospholipid_mol_ratio, mol_step, comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius ) for weight_ratio in weight_ratios: for peg in peg_values: - remaining = 1.0 - peg + remaining = 100.0 - peg for cationic_mol in cationic_mols: for phospholipid_mol in phospholipid_mols: cholesterol_mol = remaining - cationic_mol - phospholipid_mol - # 检查约束 if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and @@ -677,7 +673,8 @@ def select_top_k( def generate_single_seed_grid( seed: Formulation, - step_size: float, + mol_step: float, + wr_step: float, radius: int = 2, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: @@ -686,7 +683,8 @@ def generate_single_seed_grid( Args: seed: 种子配方 - step_size: 步长 + mol_step: mol ratio 步长 (百分数) + wr_step: weight ratio 步长 radius: 扩展半径 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) @@ -698,35 +696,31 @@ def generate_single_seed_grid( grid_set = set() - # Weight ratio weight_ratios = generate_grid_values( - seed.cationic_lipid_to_mrna_ratio, step_size, + seed.cationic_lipid_to_mrna_ratio, wr_step, comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius ) - # PEG (始终用最小步长) peg_values = generate_grid_values( seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius ) - # Mol ratios cationic_mols = generate_grid_values( - seed.cationic_lipid_mol_ratio, step_size, + seed.cationic_lipid_mol_ratio, mol_step, comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius ) phospholipid_mols = generate_grid_values( - seed.phospholipid_mol_ratio, step_size, + seed.phospholipid_mol_ratio, mol_step, comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius ) for weight_ratio in weight_ratios: for peg in peg_values: - remaining = 1.0 - peg + remaining = 100.0 - peg for cationic_mol in cationic_mols: for phospholipid_mol in phospholipid_mols: cholesterol_mol = remaining - cationic_mol - phospholipid_mol - # 检查约束 if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and @@ -751,6 +745,7 @@ def optimize( num_seeds: Optional[int] = None, top_per_seed: int = 1, step_sizes: Optional[List[float]] = None, + wr_step_sizes: Optional[List[float]] = None, comp_ranges: Optional[CompRanges] = None, routes: Optional[List[str]] = None, scoring_weights: Optional[ScoringWeights] = None, @@ -772,7 +767,8 @@ def optimize( top_k: 最终返回的最优配方数 num_seeds: 第一次迭代后保留的种子点数量(默认为 top_k * 5) top_per_seed: 每个种子点的邻域搜索后保留的局部最优点数量 - step_sizes: 每轮迭代的步长列表(默认为 [0.10, 0.02, 0.01]) + step_sizes: mol ratio 每轮迭代的步长列表 (百分数,默认 [10, 2, 1]) + wr_step_sizes: weight ratio 每轮迭代的步长列表 (默认 [5, 2, 1]) comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) routes: 给药途径列表(默认使用 ROUTE_OPTIONS) scoring_weights: 评分权重配置(默认仅按 biodist 排序) @@ -781,49 +777,52 @@ def optimize( Returns: 最终 top-k 配方列表 """ - # 默认 num_seeds 为 top_k * 5 if num_seeds is None: num_seeds = top_k * 5 - # 默认步长 if step_sizes is None: - step_sizes = ITERATION_STEP_SIZES + step_sizes = MOL_STEP_SIZES + + if wr_step_sizes is None: + wr_step_sizes = WR_STEP_SIZES + + # 两组步长长度必须一致 + if len(wr_step_sizes) != len(step_sizes): + raise ValueError( + f"step_sizes ({len(step_sizes)}) 和 wr_step_sizes ({len(wr_step_sizes)}) 长度不一致" + ) - # 默认组分范围 if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES - # 默认给药途径 if routes is None: routes = ROUTE_OPTIONS - # 默认评分权重 if scoring_weights is None: scoring_weights = DEFAULT_SCORING_WEIGHTS - # 评分函数(用于 Formulation 对象排序) def _score(f: Formulation) -> float: return compute_formulation_score(f, organ, scoring_weights) logger.info(f"Starting optimization for organ: {organ}") logger.info(f"SMILES: {smiles}") logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}") - logger.info(f"Step sizes: {step_sizes}") + logger.info(f"Mol step sizes: {step_sizes}, WR step sizes: {wr_step_sizes}") logger.info(f"Routes: {routes}") logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}") logger.info(f"Comp ranges: {comp_ranges.to_dict()}") seeds = None - for iteration, step_size in enumerate(step_sizes): + for iteration, (mol_step, wr_step) in enumerate(zip(step_sizes, wr_step_sizes)): logger.info(f"\n{'='*60}") - logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, step_size={step_size}") + logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, mol_step={mol_step}, wr_step={wr_step}") logger.info(f"{'='*60}") if seeds is None: # ==================== 第一次迭代:全局稀疏搜索 ==================== logger.info("Generating initial grid (global sparse search)...") - grid = generate_initial_grid(step_size, comp_ranges) + grid = generate_initial_grid(mol_step, wr_step, comp_ranges) logger.info(f"Grid size: {len(grid)} comp combinations") @@ -853,8 +852,7 @@ def optimize( all_local_best = [] for seed_idx, seed in enumerate(seeds): - # 为当前种子点生成邻域网格 - local_grid = generate_single_seed_grid(seed, step_size, radius=2, comp_ranges=comp_ranges) + local_grid = generate_single_seed_grid(seed, mol_step, wr_step, radius=2, comp_ranges=comp_ranges) if len(local_grid) == 0: # 如果没有新的网格点,保留原种子 @@ -946,7 +944,8 @@ def main( top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"), num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"), top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"), - step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Comma-separated step sizes (e.g., '0.10,0.02,0.01')"), + step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Mol ratio step sizes, comma-separated (e.g., '10,2,1')"), + wr_step_sizes: Optional[str] = typer.Option(None, "--wr-step-sizes", help="Weight ratio step sizes, comma-separated (e.g., '5,2,1')"), batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"), device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"), ): @@ -961,7 +960,7 @@ def main( 示例: python -m app.optimize --smiles "CC(C)..." --organ liver python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2 - python -m app.optimize -s "CC(C)..." -o liver -S "0.10,0.05,0.02" + python -m app.optimize -s "CC(C)..." -o liver -S "10,2,1" --wr-step-sizes "5,2,1" """ # 验证器官 if organ not in AVAILABLE_ORGANS: @@ -977,12 +976,19 @@ def main( logger.error(f"Invalid step sizes format: {step_sizes}") raise typer.Exit(1) + parsed_wr_step_sizes = None + if wr_step_sizes: + try: + parsed_wr_step_sizes = [float(s.strip()) for s in wr_step_sizes.split(",")] + except ValueError: + logger.error(f"Invalid wr step sizes format: {wr_step_sizes}") + raise typer.Exit(1) + # 加载模型 logger.info(f"Loading model from {model_path}") device = torch.device(device) model = load_model(model_path, device) - # 执行优化(层级搜索策略) results = optimize( smiles=smiles, organ=organ, @@ -992,6 +998,7 @@ def main( num_seeds=num_seeds, top_per_seed=top_per_seed, step_sizes=parsed_step_sizes, + wr_step_sizes=parsed_wr_step_sizes, batch_size=batch_size, ) diff --git a/scripts/preprocess_internal.py b/scripts/preprocess_internal.py index cc3f200..2e54998 100644 --- a/scripts/preprocess_internal.py +++ b/scripts/preprocess_internal.py @@ -37,10 +37,6 @@ def main( .transform(lambda x: (x - x.mean()) / x.std()) ) - # 将 Cationic_Lipid_Mol_Ratio,Phospholipid_Mol_Ratio,Cholesterol_Mol_Ratio,PEG_Lipid_Mol_Ratio 四列的百分数转换为小数 - logger.info("Converting percentage columns to decimal...") - df[["Cationic_Lipid_Mol_Ratio", "Phospholipid_Mol_Ratio", "Cholesterol_Mol_Ratio", "PEG_Lipid_Mol_Ratio"]] = df[["Cationic_Lipid_Mol_Ratio", "Phospholipid_Mol_Ratio", "Cholesterol_Mol_Ratio", "PEG_Lipid_Mol_Ratio"]] / 100 - # 对 size 列取 log logger.info("Log-transforming size column...") df["size"] = pd.to_numeric(df["size"], errors="coerce")