From 89b0d467eaead82a4efbf46583943e9618e44546 Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Thu, 25 Jan 2024 19:04:43 +0800
Subject: [PATCH] Integrate the LOOM online embedding service
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/chat/completion.py                    |  6 +-
 server/chat/file_chat.py                     | 12 ++--
 .../repository/knowledge_base_repository.py  | 10 +++-
 server/embeddings/core/embeddings_api.py     |  6 +-
 server/knowledge_base/kb_api.py              |  5 +-
 server/knowledge_base/kb_doc_api.py          |  7 ++-
 server/knowledge_base/kb_service/base.py     | 15 ++++-
 server/knowledge_base/kb_summary_api.py      | 18 +++---
 webui_pages/dialogue/dialogue.py             |  4 +-
 webui_pages/knowledge_base/knowledge_base.py | 55 ++++++++++++++-----
 webui_pages/loom_view_client.py              | 46 ++++++++++++++--
 webui_pages/utils.py                         |  6 ++
 12 files changed, 143 insertions(+), 47 deletions(-)

diff --git a/server/chat/completion.py b/server/chat/completion.py
index 31eade96..559f10bb 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -14,9 +14,9 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
-                     endpoint_host: str = Body(False, description="接入点地址"),
-                     endpoint_host_key: str = Body(False, description="接入点key"),
-                     endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+                     endpoint_host: str = Body(None, description="接入点地址"),
+                     endpoint_host_key: str = Body(None, description="接入点key"),
+                     endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                      model_name: str = Body(None, description="LLM 模型名称。"),
                      temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index 2b67a86f..5fac6f0a 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -57,9 +57,9 @@ def _parse_files_in_thread(


 def upload_temp_docs(
-        endpoint_host: str = Body(False, description="接入点地址"),
-        endpoint_host_key: str = Body(False, description="接入点key"),
-        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+        endpoint_host: str = Body(None, description="接入点地址"),
+        endpoint_host_key: str = Body(None, description="接入点key"),
+        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
         prev_id: str = Form(None, description="前知识库ID"),
         chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
@@ -110,9 +110,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                                                      "content": "虎头虎脑"}]]
                                  ),
                    stream: bool = Body(False, description="流式输出"),
-                   endpoint_host: str = Body(False, description="接入点地址"),
-                   endpoint_host_key: str = Body(False, description="接入点key"),
-                   endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+                   endpoint_host: str = Body(None, description="接入点地址"),
+                   endpoint_host_key: str = Body(None, description="接入点key"),
+                   endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
                    model_name: str = Body(None, description="LLM 模型名称。"),
                    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
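# Note on the Body(False) -> Body(None) changes above: these fields are all
# declared as str, so a bool False default is wrongly typed and misleading;
# None is the idiomatic "not provided" sentinel for an optional body field.
# A minimal, runnable sketch of the corrected pattern (the route and field
# names here are illustrative, not part of the patch):
from typing import Optional

from fastapi import Body, FastAPI

app = FastAPI()


@app.post("/demo/endpoint_echo")
def endpoint_echo(
        endpoint_host: Optional[str] = Body(None, description="endpoint address"),
) -> dict:
    # endpoint_host stays None unless the caller actually sends a value
    return {"endpoint_host": endpoint_host}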
diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py
index d6997e4c..d241328e 100644
--- a/server/db/repository/knowledge_base_repository.py
+++ b/server/db/repository/knowledge_base_repository.py
@@ -3,16 +3,22 @@ from server.db.session import with_session


 @with_session
-def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
+def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None,
+                 endpoint_host_key: str = None, endpoint_host_proxy: str = None):
     # 创建知识库实例
     kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
     if not kb:
-        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
+        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model,
+                                endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key,
+                                endpoint_host_proxy=endpoint_host_proxy)
         session.add(kb)
     else: # update kb with new vs_type and embed_model
         kb.kb_info = kb_info
         kb.vs_type = vs_type
         kb.embed_model = embed_model
+        kb.endpoint_host = endpoint_host
+        kb.endpoint_host_key = endpoint_host_key
+        kb.endpoint_host_proxy = endpoint_host_proxy
     return True

diff --git a/server/embeddings/core/embeddings_api.py b/server/embeddings/core/embeddings_api.py
index 9f81408c..5ac5d6b7 100644
--- a/server/embeddings/core/embeddings_api.py
+++ b/server/embeddings/core/embeddings_api.py
@@ -78,9 +78,9 @@ async def aembed_texts(

 def embed_texts_endpoint(
         texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
-        endpoint_host: str = Body(False, description="接入点地址"),
-        endpoint_host_key: str = Body(False, description="接入点key"),
-        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+        endpoint_host: str = Body(None, description="接入点地址"),
+        endpoint_host_key: str = Body(None, description="接入点key"),
+        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
         to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
 ) -> BaseResponse:
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index 0d2fbce9..cab28d00 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -15,6 +15,9 @@ def list_kbs():
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
               embed_model: str = Body(EMBEDDING_MODEL),
+              endpoint_host: str = Body(None, description="接入点地址"),
+              endpoint_host_key: str = Body(None, description="接入点key"),
+              endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
@@ -28,7 +31,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),

     kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
     try:
-        kb.create_kb()
+        kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy)
     except Exception as e:
         msg = f"创建知识库出错: {e}"
         logger.error(f'{e.__class__.__name__}: {msg}',
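# With the kb_api.py change above, create_kb now forwards the endpoint settings
# to KBService.create_kb. How a client might call the updated route; the base
# URL, model name, and credential values below are illustrative assumptions,
# only the field names come from this patch:
import requests

payload = {
    "knowledge_base_name": "samples",
    "vector_store_type": "faiss",
    "embed_model": "text-embedding-v1",           # hypothetical embedding model
    "endpoint_host": "http://127.0.0.1:8000/v1",  # hypothetical LOOM endpoint
    "endpoint_host_key": "EMPTY",                 # hypothetical key
    "endpoint_host_proxy": None,
}
resp = requests.post("http://127.0.0.1:7861/knowledge_base/create_knowledge_base",
                     json=payload)  # assumes the API server listens on port 7861
print(resp.json())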
description="接入点地址"), + endpoint_host_key: str = Body(None, description="接入点key"), + endpoint_host_proxy: str = Body(None, description="接入点代理地址"), embed_model: str = Body(EMBEDDING_MODEL), chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), @@ -366,7 +369,9 @@ def recreate_vector_store( else: if kb.exists(): kb.clear_vs() - kb.create_kb() + kb.create_kb(endpoint_host=endpoint_host, + endpoint_host_key=endpoint_host_key, + endpoint_host_proxy=endpoint_host_proxy) files = list_files_from_folder(knowledge_base_name) kb_files = [(file, knowledge_base_name) for file in files] i = 0 diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 8fb6d148..6e61b840 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -58,14 +58,23 @@ class KBService(ABC): ''' pass - def create_kb(self): + def create_kb(self, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None): """ 创建知识库 """ if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) - self.do_create_kb() - status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) + + status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model, + endpoint_host=endpoint_host, + endpoint_host_key=endpoint_host_key, + endpoint_host_proxy=endpoint_host_proxy) + + if status: + self.do_create_kb() return status def clear_vs(self): diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 00974d1c..674fde11 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -19,9 +19,9 @@ def recreate_summary_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - endpoint_host: str = Body(False, description="接入点地址"), - endpoint_host_key: str = Body(False, description="接入点key"), - endpoint_host_proxy: str = Body(False, description="接入点代理地址"), + endpoint_host: str = Body(None, description="接入点地址"), + endpoint_host_key: str = Body(None, description="接入点key"), + endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -112,9 +112,9 @@ def summary_file_to_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - endpoint_host: str = Body(False, description="接入点地址"), - endpoint_host_key: str = Body(False, description="接入点key"), - endpoint_host_proxy: str = Body(False, description="接入点代理地址"), + endpoint_host: str = Body(None, description="接入点地址"), + endpoint_host_key: str = Body(None, description="接入点key"), + endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -196,9 +196,9 @@ def summary_doc_ids_to_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - endpoint_host: str = Body(False, description="接入点地址"), - endpoint_host_key: str = Body(False, description="接入点key"), - 
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
index 00974d1c..674fde11 100644
--- a/server/knowledge_base/kb_summary_api.py
+++ b/server/knowledge_base/kb_summary_api.py
@@ -19,9 +19,9 @@ def recreate_summary_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(False, description="接入点地址"),
-        endpoint_host_key: str = Body(False, description="接入点key"),
-        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+        endpoint_host: str = Body(None, description="接入点地址"),
+        endpoint_host_key: str = Body(None, description="接入点key"),
+        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -112,9 +112,9 @@ def summary_file_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(False, description="接入点地址"),
-        endpoint_host_key: str = Body(False, description="接入点key"),
-        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+        endpoint_host: str = Body(None, description="接入点地址"),
+        endpoint_host_key: str = Body(None, description="接入点key"),
+        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -196,9 +196,9 @@ def summary_doc_ids_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        endpoint_host: str = Body(False, description="接入点地址"),
-        endpoint_host_key: str = Body(False, description="接入点key"),
-        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
+        endpoint_host: str = Body(None, description="接入点地址"),
+        endpoint_host_key: str = Body(None, description="接入点key"),
+        endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
         model_name: str = Body(None, description="LLM 模型名称。"),
         temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 9ba7e0c3..c9759115 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -4,7 +4,7 @@ import streamlit as st
 from streamlit_antd_components.utils import ParseItems

 from webui_pages.dialogue.utils import process_files
-from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
+from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \
     get_select_model_endpoint
 from webui_pages.utils import *
 from streamlit_chatbox import *
@@ -132,7 +132,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     conversation_id = st.session_state["conversation_ids"][conversation_name]

     with st.expander("模型选择"):
-        plugins_menu = build_plugins_name()
+        plugins_menu = build_providers_model_plugins_name()

         items, _ = ParseItems(plugins_menu).multi()

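# dialogue.py now uses build_providers_model_plugins_name, which (as defined
# later in this patch) keeps only models whose "providers" include "model";
# its sibling build_providers_embedding_plugins_name filters on "embedding".
# A stand-alone sketch of that filter; the sample data shape is a guess based
# on how the patch indexes model["providers"]:
running_models = {
    "openai-worker": [
        {"model_name": "chatglm3-6b", "model_description": "chat model",
         "providers": ["model"]},
        {"model_name": "bge-large-zh", "model_description": "embedding model",
         "providers": ["embedding"]},
    ],
}


def names_with_provider(capability: str) -> dict:
    # keep only the models whose providers include the given capability
    return {worker: [m["model_name"] for m in models if capability in m["providers"]]
            for worker, models in running_models.items()}


print(names_with_provider("model"))      # {'openai-worker': ['chatglm3-6b']}
print(names_with_provider("embedding"))  # {'openai-worker': ['bge-large-zh']}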
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 1443dfb7..57dd20fe 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -1,4 +1,8 @@
 import streamlit as st
+from streamlit_antd_components.utils import ParseItems
+
+from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \
+    set_llm_select, set_embed_select, get_select_embed_endpoint
 from webui_pages.utils import *
 from st_aggrid import AgGrid, JsCode
 from st_aggrid.grid_options_builder import GridOptionsBuilder
@@ -7,12 +11,16 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
-                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY)
 from server.utils import list_embed_models
+
+import streamlit_antd_components as sac
 import os
 import time

+# SENTENCE_SIZE = 100
+
 cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
@@ -96,24 +104,37 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="kb_info",
         )

-        cols = st.columns(2)
+        col0, _ = st.columns([3, 1])

         vs_types = list(kbs_config.keys())
-        vs_type = cols[0].selectbox(
+        vs_type = col0.selectbox(
             "向量库类型",
             vs_types,
             index=vs_types.index(DEFAULT_VS_TYPE),
             key="vs_type",
         )

-        embed_models = list_embed_models()
+        col1, _ = st.columns([3, 1])
+        with col1:
+            col1.text("Embedding 模型")
+            plugins_menu = build_providers_embedding_plugins_name()

-        embed_model = cols[1].selectbox(
-            "Embedding 模型",
-            embed_models,
-            index=embed_models.index(EMBEDDING_MODEL),
-            key="embed_model",
-        )
+            embed_models = list_embed_models()
+            menu_item_children = []
+            for model in embed_models:
+                menu_item_children.append(sac.MenuItem(model, description=model))
+
+            plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children))
+
+            items, _ = ParseItems(plugins_menu).multi()
+
+            if len(plugins_menu) > 0:
+
+                llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False)
+                plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
+                set_embed_select(plugins_info, llm_model_worker)
+            else:
+                st.info("没有可用的插件")

         submit_create_kb = st.form_submit_button(
             "新建",
@@ -122,15 +143,23 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
         )

         if submit_create_kb:
+
+            endpoint_host, select_embed_model_name = get_select_embed_endpoint()
             if not kb_name or not kb_name.strip():
                 st.error(f"知识库名称不能为空!")
             elif kb_name in kb_list:
                 st.error(f"名为 {kb_name} 的知识库已经存在!")
+            elif select_embed_model_name is None:
+                st.error(f"请选择Embedding模型!")
             else:
+
                 ret = api.create_knowledge_base(
                     knowledge_base_name=kb_name,
                     vector_store_type=vs_type,
-                    embed_model=embed_model,
+                    embed_model=select_embed_model_name,
+                    endpoint_host=endpoint_host,
+                    endpoint_host_key=OPENAI_KEY,
+                    endpoint_host_proxy=OPENAI_PROXY,
                 )
                 st.toast(ret.get("msg", " "))
                 st.session_state["selected_kb_name"] = kb_name
@@ -249,7 +278,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
     # 将文件分词并加载到向量库中
     if cols[1].button(
             "重新添加至向量库" if selected_rows and (
-                pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
+                    pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
             disabled=not file_exists(kb, selected_rows)[0],
             use_container_width=True,
     ):
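# The create-KB form above stores the picked menu items via set_embed_select and
# reads them back with get_select_embed_endpoint (added below), which maps the
# chosen plugin to its endpoint_host. The same lookup with plain dicts standing
# in for st.session_state; all values are made up for illustration:
launch_subscribe_info = {
    "openai-plugin": {"adapter_description": {"endpoint_host": "http://127.0.0.1:8000/v1"}},
}
select_embed_plugins_info = {"label": "openai-plugin"}  # selected plugin menu item
select_embed_model_worker = {"label": "bge-large-zh"}   # selected model menu item


def get_select_embed_endpoint():
    if select_embed_plugins_info is None or select_embed_model_worker is None:
        raise ValueError("no embedding model selected")
    plugins_name = select_embed_plugins_info["label"]
    model_name = select_embed_model_worker["label"]
    endpoint_host = None
    if plugins_name in launch_subscribe_info:
        adapter = launch_subscribe_info[plugins_name]
        endpoint_host = adapter.get("adapter_description", {}).get("endpoint_host", "")
    return endpoint_host, model_name


print(get_select_embed_endpoint())  # ('http://127.0.0.1:8000/v1', 'bge-large-zh')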
+ result1.get("detail", "")) time.sleep(2) update_store() @@ -103,7 +104,7 @@ def stop_worker(): update_store() -def build_plugins_name(): +def build_providers_model_plugins_name(): import streamlit_antd_components as sac if "run_plugins_list" not in st.session_state: return [] @@ -112,7 +113,25 @@ def build_plugins_name(): for key, value in st.session_state.list_running_models.items(): menu_item_children = [] for model in value: - menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"])) + if "model" in model["providers"]: + menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"])) + + menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children)) + + return menu_items + + +def build_providers_embedding_plugins_name(): + import streamlit_antd_components as sac + if "run_plugins_list" not in st.session_state: + return [] + # 按照模型构建sac.menu(菜单 + menu_items = [] + for key, value in st.session_state.list_running_models.items(): + menu_item_children = [] + for model in value: + if "embedding" in model["providers"]: + menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"])) menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children)) @@ -144,3 +163,22 @@ def get_select_model_endpoint() -> Tuple[str, str]: adapter_description = st.session_state.launch_subscribe_info[plugins_name] endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "") return endpoint_host, select_model_name + + +def set_embed_select(plugins_info, embed_model_worker): + st.session_state["select_embed_plugins_info"] = plugins_info + st.session_state["select_embed_model_worker"] = embed_model_worker + + +def get_select_embed_endpoint() -> Tuple[str, str]: + select_embed_plugins_info = st.session_state["select_embed_plugins_info"] + select_embed_model_worker = st.session_state["select_embed_model_worker"] + if select_embed_plugins_info is None or select_embed_model_worker is None: + raise ValueError("select_embed_plugins_info or select_embed_model_worker is None") + embed_plugins_name = st.session_state["select_embed_plugins_info"]['label'] + select_embed_model_name = st.session_state["select_embed_model_worker"]['label'] + endpoint_host = None + if embed_plugins_name in st.session_state.launch_subscribe_info: + adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name] + endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "") + return endpoint_host, select_embed_model_name diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 09dd44d9..99716e14 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -382,6 +382,9 @@ class ApiRequest: knowledge_base_name: str, vector_store_type: str = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 @@ -390,6 +393,9 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model, + "endpoint_host": endpoint_host, + "endpoint_host_key": endpoint_host_key, + "endpoint_host_proxy": endpoint_host_proxy, } response = self.post(