From d6620eb6284f6a4daa692cbe19953da7cc69595f Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Thu, 25 Jan 2024 20:12:18 +0800 Subject: [PATCH] =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E5=9C=A8=E7=BA=BFap?= =?UTF-8?q?i=E6=8E=A5=E5=85=A5=E7=82=B9=E9=85=8D=E7=BD=AE=E5=9C=A8?= =?UTF-8?q?=E7=BA=BFapi=E6=8E=A5=E5=85=A5=E7=82=B9=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/api.py | 9 +++++ server/knowledge_base/kb_doc_api.py | 20 +++++++++++ server/knowledge_base/kb_service/base.py | 12 ++++++- webui_pages/knowledge_base/knowledge_base.py | 36 +++++++++++++++++++- webui_pages/utils.py | 20 +++++++++++ 5 files changed, 95 insertions(+), 2 deletions(-) diff --git a/server/api.py b/server/api.py index d5477d4d..92e595a6 100644 --- a/server/api.py +++ b/server/api.py @@ -2,6 +2,8 @@ import nltk import sys import os +from server.knowledge_base.kb_doc_api import update_kb_endpoint + sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs import VERSION, MEDIA_PATH @@ -168,6 +170,13 @@ def mount_knowledge_routes(app: FastAPI): response_model=BaseResponse, summary="更新知识库介绍" )(update_info) + + app.post("/knowledge_base/update_kb_endpoint", + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="更新知识库在线api接入点配置" + )(update_kb_endpoint) + app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index b80e4140..ceb4deb7 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -230,6 +230,26 @@ def update_info( return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info}) +def update_kb_endpoint( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + endpoint_host: str = Body(None, description="接入点地址"), + endpoint_host_key: str = Body(None, description="接入点key"), + endpoint_host_proxy: str = Body(None, description="接入点代理地址"), +): + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy) + + return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成", + data={"endpoint_host": endpoint_host, + "endpoint_host_key": endpoint_host_key, + "endpoint_host_proxy": endpoint_host_proxy}) + + def update_docs( knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 6e61b840..336f4b85 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -6,7 +6,7 @@ from langchain.docstore.document import Document from server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, - load_kb_from_db, get_kb_detail, + load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db, ) from server.db.repository.knowledge_file_repository import ( add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, @@ -144,6 +144,16 @@ class KBService(ABC): status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) return status + def update_kb_endpoint(self, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None): + """ + 更新知识库在线api接入点配置 + """ + status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy) + return status + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 57dd20fe..37a7b1aa 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -169,6 +169,9 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): elif selected_kb: kb = selected_kb st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] + st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host'] + st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key'] + st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy'] # 上传文件 files = st.file_uploader("上传知识文件:", [i for ls in LOADER_DICT.values() for i in ls], @@ -182,6 +185,37 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.session_state["selected_kb_info"] = kb_info api.update_kb_info(kb, kb_info) + if st.session_state["kb_endpoint_host"] is not None: + with st.expander( + "在线api接入点配置", + expanded=True, + ): + endpoint_host = st.text_input( + "接入点地址", + placeholder="接入点地址", + key="endpoint_host", + value=st.session_state["kb_endpoint_host"], + ) + endpoint_host_key = st.text_input( + "接入点key", + placeholder="接入点key", + key="endpoint_host_key", + value=st.session_state["kb_endpoint_host_key"], + ) + endpoint_host_proxy = st.text_input( + "接入点代理地址", + placeholder="接入点代理地址", + key="endpoint_host_proxy", + value=st.session_state["kb_endpoint_host_proxy"], + ) + if endpoint_host != st.session_state["kb_endpoint_host"] \ + or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \ + or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]: + st.session_state["kb_endpoint_host"] = endpoint_host + st.session_state["kb_endpoint_host_key"] = endpoint_host_key + st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy + api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy) + # with st.sidebar: with st.expander( "文件处理配置", @@ -278,7 +312,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): # 将文件分词并加载到向量库中 if cols[1].button( "重新添加至向量库" if selected_rows and ( - pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", + pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", disabled=not file_exists(kb, selected_rows)[0], use_container_width=True, ): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 99716e14..40d1377b 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -562,6 +562,26 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) + def update_kb_endpoint(self, + knowledge_base_name, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None): + ''' + 对应api.py/knowledge_base/update_info接口 + ''' + data = { + "knowledge_base_name": knowledge_base_name, + "endpoint_host": endpoint_host, + "endpoint_host_key": endpoint_host_key, + "endpoint_host_proxy": endpoint_host_proxy, + } + + response = self.post( + "/knowledge_base/update_kb_endpoint", + json=data, + ) + return self._get_response_value(response, as_json=True) def update_kb_docs( self, knowledge_base_name: str,