""" Streamlit 配方优化交互界面 启动应用: streamlit run app/app.py Docker 环境变量: API_URL: API 服务地址 (默认: http://localhost:8000) """ import io import os from datetime import datetime import httpx import pandas as pd import streamlit as st # ============ 配置 ============ # 从环境变量读取 API 地址,支持 Docker 环境 API_URL = os.environ.get("API_URL", "http://localhost:8000") AVAILABLE_ORGANS = [ "liver", "spleen", "lung", "heart", "kidney", "muscle", "lymph_nodes", ] ORGAN_LABELS = { "liver": "肝脏 (Liver)", "spleen": "脾脏 (Spleen)", "lung": "肺 (Lung)", "heart": "心脏 (Heart)", "kidney": "肾脏 (Kidney)", "muscle": "肌肉 (Muscle)", "lymph_nodes": "淋巴结 (Lymph Nodes)", } AVAILABLE_ROUTES = [ "intravenous", "intramuscular", ] ROUTE_LABELS = { "intravenous": "静脉注射 (Intravenous)", "intramuscular": "肌肉注射 (Intramuscular)", } # ============ 页面配置 ============ st.set_page_config( page_title="LNP 配方优化", page_icon="🧬", layout="wide", initial_sidebar_state="expanded", ) # ============ 自定义样式 ============ st.markdown(""" """, unsafe_allow_html=True) # ============ 辅助函数 ============ def check_api_status() -> bool: """检查 API 状态""" try: with httpx.Client(timeout=5) as client: response = client.get(f"{API_URL}/") return response.status_code == 200 except: return False def call_optimize_api( smiles: str, organ: str, top_k: int = 20, num_seeds: int = None, top_per_seed: int = 1, step_sizes: list = None, wr_step_sizes: list = None, comp_ranges: dict = None, routes: list = None, scoring_weights: dict = None, ) -> dict: """调用优化 API""" payload = { "smiles": smiles, "organ": organ, "top_k": top_k, "num_seeds": num_seeds, "top_per_seed": top_per_seed, "step_sizes": step_sizes, "wr_step_sizes": wr_step_sizes, "comp_ranges": comp_ranges, "routes": routes, "scoring_weights": scoring_weights, } with httpx.Client(timeout=600) as client: # 10 分钟超时(自定义参数可能需要更长时间) response = client.post( f"{API_URL}/optimize", json=payload, ) response.raise_for_status() return response.json() # PDI 分类标签 PDI_CLASS_LABELS = { 0: "<0.2 (优)", 1: "0.2-0.3 (良)", 2: "0.3-0.4 (中)", 3: ">0.4 (差)", } # EE 分类标签 EE_CLASS_LABELS = { 0: "<50% (低)", 1: "50-80% (中)", 2: ">80% (高)", } # 毒性分类标签 TOXIC_CLASS_LABELS = { 0: "无毒 ✓", 1: "有毒 ⚠", } def format_results_dataframe(results: dict, smiles_label: str = None) -> pd.DataFrame: """将 API 结果转换为 DataFrame""" formulations = results["formulations"] target_organ = results["target_organ"] rows = [] for f in formulations: row = {} # 如果有 SMILES 标签,添加到首列 if smiles_label: row["SMILES"] = smiles_label row.update({ "排名": f["rank"], }) # 如果有综合评分,显示在排名后面 if f.get("composite_score") is not None: row["综合评分"] = f"{f['composite_score']:.4f}" row.update({ f"{target_organ}分布": f"{f['target_biodist']*100:.8f}%", "阳离子脂质/mRNA比例": f["cationic_lipid_to_mrna_ratio"], "阳离子脂质(mol)比例": f["cationic_lipid_mol_ratio"], "磷脂(mol)比例": f["phospholipid_mol_ratio"], "胆固醇(mol)比例": f["cholesterol_mol_ratio"], "PEG脂质(mol)比例": f["peg_lipid_mol_ratio"], "辅助脂质": f["helper_lipid"], "给药途径": f["route"], }) # 添加额外预测值 if f.get("quantified_delivery") is not None: row["量化递送"] = f"{f['quantified_delivery']:.4f}" if f.get("unnormalized_delivery") is not None: row["总荧光强度"] = f"{f['unnormalized_delivery']:.4f}" if f.get("size") is not None: row["粒径(nm)"] = f"{f['size']:.1f}" if f.get("pdi_class") is not None: row["PDI"] = PDI_CLASS_LABELS.get(f["pdi_class"], str(f["pdi_class"])) if f.get("ee_class") is not None: row["包封率"] = EE_CLASS_LABELS.get(f["ee_class"], str(f["ee_class"])) if f.get("toxic_class") is not None: row["毒性"] = TOXIC_CLASS_LABELS.get(f["toxic_class"], str(f["toxic_class"])) # 添加其他器官的 biodist for organ, value in f["all_biodist"].items(): if organ != target_organ: row[f"{organ}分布"] = f"{value*100:.2f}%" rows.append(row) return pd.DataFrame(rows) def create_export_csv(df: pd.DataFrame, smiles: str, organ: str) -> str: """创建导出用的 CSV 内容""" # 添加元信息 meta_info = f"# LNP 配方优化结果\n# SMILES: {smiles}\n# 目标器官: {organ}\n# 导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" csv_content = df.to_csv(index=False) return meta_info + csv_content # ============ 主界面 ============ def main(): # 标题 st.markdown('
基于深度学习的脂质纳米颗粒配方智能优选
', unsafe_allow_html=True) # 检查 API 状态 api_online = check_api_status() # ========== 侧边栏 ========== with st.sidebar: # st.header("⚙️ 参数设置") # API 状态 if api_online: st.success("🟢 API 服务在线") else: st.error("🔴 API 服务离线") st.info("请先启动 API 服务:\n```\nuvicorn app.api:app --port 8000\n```") # st.divider() # SMILES 输入 st.subheader("🔬 分子结构") smiles_input = st.text_area( "输入阳离子脂质 SMILES", value="", height=100, placeholder="例如: CC(C)NCCNC(C)C\n多条SMILES用英文逗号分隔: SMI1,SMI2,SMI3", help="输入阳离子脂质的 SMILES 字符串。支持多条 SMILES,用英文逗号 (,) 分隔", ) # 示例 SMILES # with st.expander("📋 示例 SMILES"): # example_smiles = { # "DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC", # "简单胺": "CC(C)NCCNC(C)C", # "长链胺": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC", # } # for name, smi in example_smiles.items(): # if st.button(f"使用 {name}", key=f"example_{name}"): # st.session_state["smiles_input"] = smi # st.rerun() # st.divider() # 目标器官选择 st.subheader("🎯 目标器官") selected_organ = st.selectbox( "选择优化目标器官", options=AVAILABLE_ORGANS, format_func=lambda x: ORGAN_LABELS.get(x, x), index=0, ) # 给药途径选择 st.subheader("💉 给药途径") selected_routes = st.multiselect( "选择给药途径", options=AVAILABLE_ROUTES, default=AVAILABLE_ROUTES, format_func=lambda x: ROUTE_LABELS.get(x, x), help="选择要搜索的给药途径,可多选。至少选择一种。", ) if not selected_routes: st.warning("⚠️ 请至少选择一种给药途径") # 高级选项 with st.expander("🔧 高级选项"): st.markdown("**输出设置**") top_k = st.slider( "返回配方数量 (top_k)", min_value=5, max_value=100, value=20, step=5, help="最终返回的最优配方数量", ) st.markdown("**搜索策略**") num_seeds = st.slider( "种子点数量 (num_seeds)", min_value=10, max_value=200, value=top_k * 5, step=10, help="第一轮迭代后保留的种子点数量,更多种子点意味着更广泛的搜索", ) top_per_seed = st.slider( "每个种子的局部最优数 (top_per_seed)", min_value=1, max_value=5, value=1, step=1, help="后续迭代中,每个种子点邻域保留的局部最优数量", ) st.markdown("**迭代步长与轮数**") use_custom_steps = st.checkbox( "自定义迭代步长", value=False, help="默认 mol ratio 步长 [10, 2, 1](百分数),weight ratio 步长 [5, 2, 1],共3轮。将某轮步长设为0可减少迭代轮数。", ) if use_custom_steps: st.caption("**Mol ratio 步长 (%)**") col1, col2, col3 = st.columns(3) with col1: step1 = st.number_input( "第1轮 mol 步长", min_value=1, max_value=20, value=10, step=1, help="第1轮为全局粗搜索", key="mol_step1", ) with col2: step2 = st.number_input( "第2轮 mol 步长", min_value=0, max_value=10, value=2, step=1, help="设为0则只进行1轮搜索", key="mol_step2", ) with col3: step3 = st.number_input( "第3轮 mol 步长", min_value=0, max_value=5, value=1, step=1, help="设为0则只进行2轮搜索", key="mol_step3", ) st.caption("**Weight ratio 步长**") col1, col2, col3 = st.columns(3) with col1: wr_step1 = st.number_input( "第1轮 WR 步长", min_value=1.0, max_value=10.0, value=5.0, step=1.0, format="%.1f", key="wr_step1", ) with col2: wr_step2 = st.number_input( "第2轮 WR 步长", min_value=0.0, max_value=5.0, value=2.0, step=0.5, format="%.1f", key="wr_step2", ) with col3: wr_step3 = st.number_input( "第3轮 WR 步长", min_value=0.0, max_value=2.0, value=1.0, step=0.5, format="%.1f", key="wr_step3", ) if step2 == 0: step_sizes = [float(step1)] wr_step_sizes_val = [wr_step1] elif step3 == 0: step_sizes = [float(step1), float(step2)] wr_step_sizes_val = [wr_step1, wr_step2] else: step_sizes = [float(step1), float(step2), float(step3)] wr_step_sizes_val = [wr_step1, wr_step2, wr_step3] st.caption(f"📌 实际迭代轮数: {len(step_sizes)} 轮,mol步长: {step_sizes},WR步长: {wr_step_sizes_val}") else: step_sizes = None wr_step_sizes_val = None st.markdown("**组分范围限制**") use_custom_ranges = st.checkbox( "自定义组分取值范围", value=False, help="限制各组分的取值范围(mol 比例加起来仍为 100%)", ) if use_custom_ranges: st.caption("阳离子脂质/mRNA 重量比") col1, col2 = st.columns(2) with col1: weight_ratio_min = st.number_input("最小", min_value=1.0, max_value=50.0, value=5.0, step=1.0, format="%.1f", key="wr_min") with col2: weight_ratio_max = st.number_input("最大", min_value=1.0, max_value=50.0, value=30.0, step=1.0, format="%.1f", key="wr_max") st.caption("阳离子脂质 mol 比例 (%)") col1, col2 = st.columns(2) with col1: cationic_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=5.0, step=5.0, format="%.1f", key="cat_min") with col2: cationic_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="cat_max") st.caption("磷脂 mol 比例 (%)") col1, col2 = st.columns(2) with col1: phospholipid_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=0.0, step=5.0, format="%.1f", key="phos_min") with col2: phospholipid_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="phos_max") st.caption("胆固醇 mol 比例 (%)") col1, col2 = st.columns(2) with col1: cholesterol_mol_min = st.number_input("最小", min_value=0.0, max_value=100.0, value=0.0, step=5.0, format="%.1f", key="chol_min") with col2: cholesterol_mol_max = st.number_input("最大", min_value=0.0, max_value=100.0, value=80.0, step=5.0, format="%.1f", key="chol_max") st.caption("PEG 脂质 mol 比例 (%)") col1, col2 = st.columns(2) with col1: peg_mol_min = st.number_input("最小", min_value=0.0, max_value=20.0, value=0.0, step=1.0, format="%.1f", key="peg_min") with col2: peg_mol_max = st.number_input("最大", min_value=0.0, max_value=20.0, value=5.0, step=1.0, format="%.1f", key="peg_max") comp_ranges = { "weight_ratio_min": weight_ratio_min, "weight_ratio_max": weight_ratio_max, "cationic_mol_min": cationic_mol_min, "cationic_mol_max": cationic_mol_max, "phospholipid_mol_min": phospholipid_mol_min, "phospholipid_mol_max": phospholipid_mol_max, "cholesterol_mol_min": cholesterol_mol_min, "cholesterol_mol_max": cholesterol_mol_max, "peg_mol_min": peg_mol_min, "peg_mol_max": peg_mol_max, } min_sum = cationic_mol_min + phospholipid_mol_min + cholesterol_mol_min + peg_mol_min max_sum = cationic_mol_max + phospholipid_mol_max + cholesterol_mol_max + peg_mol_max if min_sum > 100.0 or max_sum < 100.0: st.warning("⚠️ 当前范围设置可能无法生成有效配方(mol 比例需加起来为 100%)") else: comp_ranges = None st.markdown("**评分/排序权重**") use_custom_scoring = st.checkbox( "自定义评分权重", value=False, help="默认仅按目标器官分布排序。开启后可自定义多目标加权评分,总分 = 各项score之和。", ) if use_custom_scoring: st.caption("**回归任务权重**") sw_biodist = st.number_input( "器官分布 (Biodistribution)", min_value=0.00, max_value=10.00, value=0.30, step=0.05, format="%.2f", key="sw_biodist", help="score = biodist_value × weight", ) sw_delivery = st.number_input( "量化递送 (Quantified Delivery)", min_value=0.00, max_value=10.00, value=0.25, step=0.05, format="%.2f", key="sw_delivery", help="score = normalize(delivery, route) × weight", ) sw_size = st.number_input( "粒径 (Size, 80-150nm)", min_value=0.00, max_value=10.00, value=0.05, step=0.05, format="%.2f", key="sw_size", help="score = (1 if 60≤size≤150 else 0) × weight", ) st.caption("**包封率 (EE) 分类权重**") col1, col2, col3 = st.columns(3) with col1: sw_ee0 = st.number_input("<50% (低)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_ee0") with col2: sw_ee1 = st.number_input("50-80% (中)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_ee1") with col3: sw_ee2 = st.number_input(">80% (高)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_ee2") st.caption("**PDI 分类权重**") col1, col2, col3, col4 = st.columns(4) with col1: sw_pdi0 = st.number_input("<0.2 (优)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_pdi0") with col2: sw_pdi1 = st.number_input("0.2-0.3 (良)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_pdi1") with col3: sw_pdi2 = st.number_input("0.3-0.4 (中)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi2") with col4: sw_pdi3 = st.number_input(">0.4 (差)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi3") st.caption("**毒性分类权重**") col1, col2 = st.columns(2) with col1: sw_toxic0 = st.number_input("无毒", min_value=0.00, max_value=1.00, value=0.20, step=0.05, format="%.2f", key="sw_toxic0") with col2: sw_toxic1 = st.number_input("有毒", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="sw_toxic1") scoring_weights = { "biodist_weight": sw_biodist, "delivery_weight": sw_delivery, "size_weight": sw_size, "ee_class_weights": [sw_ee0, sw_ee1, sw_ee2], "pdi_class_weights": [sw_pdi0, sw_pdi1, sw_pdi2, sw_pdi3], "toxic_class_weights": [sw_toxic0, sw_toxic1], } else: scoring_weights = None # 使用默认值(仅按 biodist 排序) st.divider() # 优化按钮 optimize_button = st.button( "🚀 开始配方优选", type="primary", use_container_width=True, disabled=not api_online or not smiles_input.strip() or not selected_routes, ) # ========== 主内容区 ========== # 使用 session state 存储结果 if "results" not in st.session_state: st.session_state["results"] = None if "results_df" not in st.session_state: st.session_state["results_df"] = None # 执行优化 if optimize_button and smiles_input.strip(): # 解析多条 SMILES(用逗号分隔) smiles_list = [s.strip() for s in smiles_input.split(",") if s.strip()] if not smiles_list: st.error("❌ 请输入有效的 SMILES 字符串") else: is_multi_smiles = len(smiles_list) > 1 all_results = [] all_dfs = [] errors = [] # 进度条 progress_bar = st.progress(0) status_text = st.empty() for idx, smiles in enumerate(smiles_list): status_text.text(f"🔄 正在优化 SMILES {idx + 1}/{len(smiles_list)}...") progress_bar.progress((idx) / len(smiles_list)) try: results = call_optimize_api( smiles=smiles, organ=selected_organ, top_k=top_k, num_seeds=num_seeds, top_per_seed=top_per_seed, step_sizes=step_sizes, wr_step_sizes=wr_step_sizes_val, comp_ranges=comp_ranges, routes=selected_routes, scoring_weights=scoring_weights, ) all_results.append({"smiles": smiles, "results": results}) # 为多 SMILES 模式添加 SMILES 标签 smiles_label = smiles[:30] + "..." if len(smiles) > 30 else smiles df = format_results_dataframe(results, smiles_label if is_multi_smiles else None) all_dfs.append(df) except httpx.HTTPStatusError as e: try: error_detail = e.response.json().get("detail", str(e)) except: error_detail = str(e) errors.append(f"SMILES {idx + 1}: {error_detail}") except httpx.RequestError as e: errors.append(f"SMILES {idx + 1}: API 连接失败 - {e}") except Exception as e: errors.append(f"SMILES {idx + 1}: {e}") progress_bar.progress(1.0) status_text.empty() progress_bar.empty() # 显示错误 for err in errors: st.error(f"❌ {err}") # 保存结果 if all_results: st.session_state["results"] = all_results[0]["results"] if len(all_results) == 1 else all_results st.session_state["results_df"] = pd.concat(all_dfs, ignore_index=True) if all_dfs else None st.session_state["smiles_used"] = smiles_list st.session_state["organ_used"] = selected_organ st.session_state["is_multi_smiles"] = is_multi_smiles st.success(f"✅ 优化完成!成功处理 {len(all_results)}/{len(smiles_list)} 条 SMILES") # 显示结果 if st.session_state["results"] is not None and st.session_state["results_df"] is not None: results = st.session_state["results"] df = st.session_state["results_df"] is_multi_smiles = st.session_state.get("is_multi_smiles", False) # 结果概览 if is_multi_smiles: # 多 SMILES 模式 col1, col2, col3 = st.columns(3) with col1: # 获取 target_organ(从第一个结果) first_result = results[0]["results"] if isinstance(results, list) else results target_organ = first_result["target_organ"] st.metric( "目标器官", ORGAN_LABELS.get(target_organ, target_organ).split(" ")[0], ) with col2: st.metric( "SMILES 数量", len(results) if isinstance(results, list) else 1, ) with col3: st.metric( "总配方数", len(df), ) else: # 单 SMILES 模式 col1, col2, col3 = st.columns(3) with col1: st.metric( "目标器官", ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0], ) with col2: best_score = results["formulations"][0]["target_biodist"] st.metric( "最优分布", f"{best_score*100:.2f}%", ) with col3: st.metric( "优选配方数", len(results["formulations"]), ) st.divider() # 结果表格 st.subheader("📊 优选配方列表") # 导出按钮行 col_export, col_spacer = st.columns([1, 4]) with col_export: smiles_used = st.session_state.get("smiles_used", "") if isinstance(smiles_used, list): smiles_used = ",".join(smiles_used) csv_content = create_export_csv( df, smiles_used, st.session_state.get("organ_used", ""), ) # 获取 target_organ if is_multi_smiles: target_organ = results[0]["results"]["target_organ"] if isinstance(results, list) else results["target_organ"] else: target_organ = results["target_organ"] st.download_button( label="📥 导出 CSV", data=csv_content, file_name=f"lnp_optimization_{target_organ}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv", ) # 显示表格 st.dataframe( df, use_container_width=True, hide_index=True, height=600, ) # 详细信息 # with st.expander("🔍 查看最优配方详情"): # best = results["formulations"][0] # col1, col2 = st.columns(2) # with col1: # st.markdown("**配方参数**") # st.json({ # "阳离子脂质/mRNA 比例": best["cationic_lipid_to_mrna_ratio"], # "阳离子脂质 (mol%)": best["cationic_lipid_mol_ratio"], # "磷脂 (mol%)": best["phospholipid_mol_ratio"], # "胆固醇 (mol%)": best["cholesterol_mol_ratio"], # "PEG 脂质 (mol%)": best["peg_lipid_mol_ratio"], # "辅助脂质": best["helper_lipid"], # "给药途径": best["route"], # }) # with col2: # st.markdown("**各器官 Biodistribution 预测**") # biodist_df = pd.DataFrame([ # {"器官": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"} # for k, v in best["all_biodist"].items() # ]) # st.dataframe(biodist_df, hide_index=True, use_container_width=True) else: # 欢迎信息 st.info("👈 请在左侧输入 SMILES 并选择目标器官,然后点击「开始配方优选」") # 使用说明 # with st.expander("📖 使用说明"): # st.markdown(""" # ### 如何使用 # 1. **输入 SMILES**: 在左侧输入框中输入阳离子脂质的 SMILES 字符串 # 2. **选择目标器官**: 选择您希望优化的器官靶向 # 3. **点击优选**: 系统将自动搜索最优配方组合 # 4. **查看结果**: 右侧将显示 Top-20 优选配方 # 5. **导出数据**: 点击导出按钮将结果保存为 CSV 文件 # ### 优化参数 # 系统会优化以下配方参数: # - **阳离子脂质/mRNA 比例**: 0.05 - 0.30 # - **阳离子脂质 mol 比例**: 0.05 - 0.80 # - **磷脂 mol 比例**: 0.00 - 0.80 # - **胆固醇 mol 比例**: 0.00 - 0.80 # - **PEG 脂质 mol 比例**: 0.00 - 0.05 # - **辅助脂质**: DOPE / DSPC / DOTAP # - **给药途径**: 静脉注射 / 肌肉注射 # ### 约束条件 # mol 比例之和 = 1 (阳离子脂质 + 磷脂 + 胆固醇 + PEG 脂质) # """) if __name__ == "__main__": main()