diff --git a/server/api.py b/server/api.py index d5477d4d..92e595a6 100644 --- a/server/api.py +++ b/server/api.py @@ -2,6 +2,8 @@ import nltk import sys import os +from server.knowledge_base.kb_doc_api import update_kb_endpoint + sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs import VERSION, MEDIA_PATH @@ -168,6 +170,13 @@ def mount_knowledge_routes(app: FastAPI): response_model=BaseResponse, summary="更新知识库介绍" )(update_info) + + app.post("/knowledge_base/update_kb_endpoint", + tags=["Knowledge Base Management"], + response_model=BaseResponse, + summary="更新知识库在线api接入点配置" + )(update_kb_endpoint) + app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index b80e4140..ceb4deb7 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -230,6 +230,26 @@ def update_info( return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info}) +def update_kb_endpoint( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + endpoint_host: str = Body(None, description="接入点地址"), + endpoint_host_key: str = Body(None, description="接入点key"), + endpoint_host_proxy: str = Body(None, description="接入点代理地址"), +): + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy) + + return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成", + data={"endpoint_host": endpoint_host, + "endpoint_host_key": endpoint_host_key, + "endpoint_host_proxy": endpoint_host_proxy}) + + def update_docs( knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 6e61b840..336f4b85 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -6,7 +6,7 @@ from langchain.docstore.document import Document from server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, - load_kb_from_db, get_kb_detail, + load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db, ) from server.db.repository.knowledge_file_repository import ( add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, @@ -144,6 +144,16 @@ class KBService(ABC): status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) return status + def update_kb_endpoint(self, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None): + """ + 更新知识库在线api接入点配置 + """ + status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy) + return status + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 57dd20fe..37a7b1aa 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -169,6 +169,9 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): elif selected_kb: kb = selected_kb st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] + st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host'] + st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key'] + st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy'] # 上传文件 files = st.file_uploader("上传知识文件:", [i for ls in LOADER_DICT.values() for i in ls], @@ -182,6 +185,37 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.session_state["selected_kb_info"] = kb_info api.update_kb_info(kb, kb_info) + if st.session_state["kb_endpoint_host"] is not None: + with st.expander( + "在线api接入点配置", + expanded=True, + ): + endpoint_host = st.text_input( + "接入点地址", + placeholder="接入点地址", + key="endpoint_host", + value=st.session_state["kb_endpoint_host"], + ) + endpoint_host_key = st.text_input( + "接入点key", + placeholder="接入点key", + key="endpoint_host_key", + value=st.session_state["kb_endpoint_host_key"], + ) + endpoint_host_proxy = st.text_input( + "接入点代理地址", + placeholder="接入点代理地址", + key="endpoint_host_proxy", + value=st.session_state["kb_endpoint_host_proxy"], + ) + if endpoint_host != st.session_state["kb_endpoint_host"] \ + or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \ + or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]: + st.session_state["kb_endpoint_host"] = endpoint_host + st.session_state["kb_endpoint_host_key"] = endpoint_host_key + st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy + api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy) + # with st.sidebar: with st.expander( "文件处理配置", @@ -278,7 +312,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): # 将文件分词并加载到向量库中 if cols[1].button( "重新添加至向量库" if selected_rows and ( - pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", + pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", disabled=not file_exists(kb, selected_rows)[0], use_container_width=True, ): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 99716e14..40d1377b 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -562,6 +562,26 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) + def update_kb_endpoint(self, + knowledge_base_name, + endpoint_host: str = None, + endpoint_host_key: str = None, + endpoint_host_proxy: str = None): + ''' + 对应api.py/knowledge_base/update_info接口 + ''' + data = { + "knowledge_base_name": knowledge_base_name, + "endpoint_host": endpoint_host, + "endpoint_host_key": endpoint_host_key, + "endpoint_host_proxy": endpoint_host_proxy, + } + + response = self.post( + "/knowledge_base/update_kb_endpoint", + json=data, + ) + return self._get_response_value(response, as_json=True) def update_kb_docs( self, knowledge_base_name: str,