diff --git a/configs/model_config.py.example b/configs/model_config.py.example index e2cdda79..597d4f59 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -104,6 +104,22 @@ KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowled DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}" +# 可选向量库类型及对应配置 +kbs_config = { + "faiss": { + }, + "milvus": { + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, + }, + "pg": { + "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatglm", + } +} + # 默认向量库类型。可选:faiss, milvus, pg. DEFAULT_VS_TYPE = "faiss" @@ -152,21 +168,6 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" # 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG BING_SUBSCRIPTION_KEY = "" -kbs_config = { - "faiss": { - }, - "milvus": { - "host": "127.0.0.1", - "port": "19530", - "user": "", - "password": "", - "secure": False, - }, - "pg": { - "connection_uri": "postgresql://postgres:postgres@192.168.50.128:5432/langchain_chatgml", - } -} - # 是否开启中文标题加强,以及标题增强的相关配置 # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 diff --git a/knowledge_base/samples/content/test.jpg b/knowledge_base/samples/content/test.jpg deleted file mode 100644 index 70c199b7..00000000 Binary files a/knowledge_base/samples/content/test.jpg and /dev/null differ diff --git a/knowledge_base/samples/content/test.pdf b/knowledge_base/samples/content/test.pdf deleted file mode 100644 index 3a137ad1..00000000 Binary files a/knowledge_base/samples/content/test.pdf and /dev/null differ diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 295196b4..dcd18cc4 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -272,7 +272,6 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]: "in_folder": True, "in_db": False, } - for doc in docs_in_db: doc_detail = get_file_detail(kb_name, doc) if doc_detail: diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 9f4dc602..6f1c392f 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -77,7 +77,7 @@ if __name__ == '__main__': from server.db.base import Base, engine Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") - milvusService.add_doc(KnowledgeFile("test.pdf", "test")) - milvusService.delete_doc(KnowledgeFile("test.pdf", "test")) + milvusService.add_doc(KnowledgeFile("README.md", "test")) + milvusService.delete_doc(KnowledgeFile("README.md", "test")) milvusService.do_drop_kb() print(milvusService.search_docs("测试")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 7e4bfef5..82511bba 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -78,7 +78,7 @@ if __name__ == '__main__': Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") pGKBService.create_kb() - pGKBService.add_doc(KnowledgeFile("test.pdf", "test")) - pGKBService.delete_doc(KnowledgeFile("test.pdf", "test")) + pGKBService.add_doc(KnowledgeFile("README.md", "test")) + pGKBService.delete_doc(KnowledgeFile("README.md", "test")) pGKBService.drop_kb() print(pGKBService.search_docs("测试")) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 86bf499d..7a921285 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -36,8 +36,8 @@ def dialogue_page(api: ApiRequest): with st.sidebar: with st.expander("会话管理", True): - col_input, col_btn = st.columns([2, 1]) - new_chat_name = col_input.text_input( + col_input, col_btn = st.columns([1.5, 1]) + col_input.text_input( "新会话名称", placeholder="新会话名称", label_visibility="collapsed", @@ -50,7 +50,11 @@ def dialogue_page(api: ApiRequest): chat_box.use_chat_name(new_chat_name) st.session_state.new_chat_name = "" - col_btn.button("新建会话", on_click=on_btn_new_chat) + col_btn.button( + "新建会话", + on_click=on_btn_new_chat, + use_container_width=True, + ) chat_list = chat_box.get_chat_names() cur_chat_name = sac.buttons(chat_list, 0) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 1b29cbe0..1faa3120 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,4 +1,3 @@ -from pydoc import doc import streamlit as st from webui_pages.utils import * from st_aggrid import AgGrid @@ -7,8 +6,9 @@ import pandas as pd from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details from typing import Literal, Dict, Tuple +from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE -SENTENCE_SIZE = 100 +# SENTENCE_SIZE = 100 def config_aggrid( @@ -22,155 +22,88 @@ def config_aggrid( for (col, header), kw in columns.items(): gb.configure_column(col, header, wrapHeaderText=True, **kw) gb.configure_selection( - selection_mode, - use_checkbox, + selection_mode=selection_mode, + use_checkbox=use_checkbox, # pre_selected_rows=st.session_state.get("selected_rows", [0]), ) return gb def knowledge_base_page(api: ApiRequest): - # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) kb_list = get_kb_details() kb_names = [x["kb_name"] for x in kb_list] - cols = st.columns([3, 1, 1, 3]) - new_kb_name = cols[0].text_input( - "新知识库名称", - placeholder="新知识库名称,不支持中文命名", - label_visibility="collapsed", - key="new_kb_name", - ) - - if cols[1].button( - "新建", - disabled=not bool(new_kb_name), - use_container_width=True, - ) and new_kb_name: - if new_kb_name in kb_names: - st.error(f"名为 {new_kb_name} 的知识库已经存在!") - else: - ret = api.create_knowledge_base(new_kb_name) - st.toast(ret["msg"]) - st.experimental_rerun() - - if cols[2].button( - "删除", - disabled=not bool(new_kb_name), - use_container_width=True, - ) and new_kb_name: - if new_kb_name in kb_names: - ret = api.delete_knowledge_base(new_kb_name) - st.toast(ret["msg"]) - st.experimental_rerun() - else: - st.error(f"名为 {new_kb_name} 的知识库不存在!") - - selected_kb = cols[3].selectbox( + selected_kb = st.selectbox( "请选择知识库:", - kb_list, - format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})", - label_visibility="collapsed" + kb_list + ["新建知识库"], + format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})" if type(s) != str else s, ) - if selected_kb: + if selected_kb == "新建知识库": + with st.form("新建知识库"): + + kb_name = st.text_input( + "新建知识库名称", + placeholder="新知识库名称,不支持中文命名", + key="kb_name", + ) + + cols = st.columns(2) + + vs_types = list(kbs_config.keys()) + vs_type = cols[0].selectbox( + "向量库类型", + vs_types, + index=vs_types.index(DEFAULT_VS_TYPE), + key="vs_type", + ) + + embed_models = list(embedding_model_dict.keys()) + + embed_model = cols[1].selectbox( + "Embedding 模型", + embed_models, + index=embed_models.index(EMBEDDING_MODEL), + key="embed_model", + ) + + submit_create_kb = st.form_submit_button( + "新建", + # disabled=not bool(kb_name), + use_container_width=True, + ) + + if submit_create_kb: + if not kb_name or not kb_name.strip(): + st.error(f"知识库名称不能为空!") + elif kb_name in kb_list: + st.error(f"名为 {kb_name} 的知识库已经存在!") + else: + ret = api.create_knowledge_base( + knowledge_base_name=kb_name, + vector_store_type=vs_type, + embed_model=embed_model, + ) + st.toast(ret["msg"]) + # st.experimental_rerun() + + + elif selected_kb: kb = selected_kb["kb_name"] - # 知识库详情 - st.write(f"知识库 `{kb}` 详情:") - # st.info("请选择文件,点击按钮进行操作。") - doc_details = pd.DataFrame(get_kb_doc_details(kb)) - doc_details.drop(columns=["kb_name"], inplace=True) - doc_details = doc_details[[ - "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db", - ]] - gb = config_aggrid( - doc_details, - { - ("file_name", "文档名称"): {}, - # ("file_ext", "文档类型"): {}, - # ("file_version", "文档版本"): {}, - ("document_loader", "文档加载器"): {}, - ("text_splitter", "分词器"): {}, - # ("create_time", "创建时间"): {}, - ("in_folder", "文件夹"): {}, - ("in_db", "数据库"): {}, - }, - "multiple", - ) - doc_grid = AgGrid( - doc_details, - gb.build(), - columns_auto_size_mode="FIT_CONTENTS", - theme="alpine", - custom_css={ - "#gridToolBar": {"display": "none"}, - }, - ) - - cols = st.columns(3) - selected_rows = doc_grid.get("selected_rows", []) - - cols = st.columns(4) - if selected_rows: - file_name = selected_rows[0]["file_name"] - file_path = get_file_path(kb, file_name) - with open(file_path, "rb") as fp: - cols[0].download_button( - "下载选中文档", - fp, - file_name=file_name, - use_container_width=True,) - else: - cols[0].download_button( - "下载选中文档", - "", - disabled=True, - use_container_width=True,) - - if cols[1].button( - "入库", - disabled=len(selected_rows) == 0, - use_container_width=True, - help="将文件分词并加载到向量库中", - ): - for row in selected_rows: - api.update_kb_doc(kb, row["file_name"]) - st.experimental_rerun() - - if cols[2].button( - "出库", - disabled=len(selected_rows) == 0, - use_container_width=True, - help="将文件从向量库中删除,但不删除文件本身。" - ): - for row in selected_rows: - api.delete_kb_doc(kb, row["file_name"]) - st.experimental_rerun() - - if cols[3].button( - "删除选中文档!", - type="primary", - use_container_width=True, - ): - for row in selected_rows: - ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret["msg"]) - st.experimental_rerun() - - st.divider() + # 上传文件 # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) files = st.file_uploader("上传知识文件", - [i for ls in LOADER_DICT.values() for i in ls], - accept_multiple_files=True, - ) - cols = st.columns([3, 1]) - if cols[0].button( + [i for ls in LOADER_DICT.values() for i in ls], + accept_multiple_files=True, + ) + + if st.button( "添加文件到知识库", - help="请先上传文件,再点击添加", - use_container_width=True, + # help="请先上传文件,再点击添加", + # use_container_width=True, disabled=len(files) == 0, ): for f in files: @@ -181,10 +114,102 @@ def knowledge_base_page(api: ApiRequest): st.toast(ret["msg"], icon="❌") st.session_state.files = [] + st.divider() + + # 知识库详情 + # st.info("请选择文件,点击按钮进行操作。") + doc_details = pd.DataFrame(get_kb_doc_details(kb)) + if not len(doc_details): + st.info(f"知识库 `{kb}` 中暂无文件") + else: + st.write(f"知识库 `{kb}` 中已有文件:") + doc_details.drop(columns=["kb_name"], inplace=True) + doc_details = doc_details[[ + "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db", + ]] + + gb = config_aggrid( + doc_details, + { + ("file_name", "文档名称"): {}, + # ("file_ext", "文档类型"): {}, + # ("file_version", "文档版本"): {}, + ("document_loader", "文档加载器"): {}, + ("text_splitter", "分词器"): {}, + # ("create_time", "创建时间"): {}, + ("in_folder", "源文件"): {}, + ("in_db", "向量库"): {}, + }, + "multiple", + ) + + doc_grid = AgGrid( + doc_details, + gb.build(), + columns_auto_size_mode="FIT_CONTENTS", + theme="alpine", + custom_css={ + "#gridToolBar": {"display": "none"}, + }, + ) + + selected_rows = doc_grid.get("selected_rows", []) + + cols = st.columns(4) + if selected_rows: + file_name = selected_rows[0]["file_name"] + file_path = get_file_path(kb, file_name) + with open(file_path, "rb") as fp: + cols[0].download_button( + "下载选中文档", + fp, + file_name=file_name, + use_container_width=True, ) + else: + cols[0].download_button( + "下载选中文档", + "", + disabled=True, + use_container_width=True, ) + + # 将文件分词并加载到向量库中 + if cols[1].button( + "添加至向量库", # "重新添加至向量库" + disabled=len(selected_rows) == 0, + use_container_width=True, + ): + for row in selected_rows: + api.update_kb_doc(kb, row["file_name"]) + st.experimental_rerun() + + # 将文件从向量库中删除,但不删除文件本身。 + if cols[2].button( + "从向量库删除", + disabled=len(selected_rows) == 0, + use_container_width=True, + ): + for row in selected_rows: + api.delete_kb_doc(kb, row["file_name"]) + st.experimental_rerun() + + if cols[3].button( + "删除选中文档!", + type="primary", + use_container_width=True, + ): + for row in selected_rows: + ret = api.delete_kb_doc(kb, row["file_name"], True) + st.toast(ret["msg"]) + st.experimental_rerun() + + st.divider() + + cols = st.columns(3) + # todo: freezed - if cols[1].button( - "重建知识库", - help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", + if cols[0].button( + "依据源文件重建向量库", + # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", use_container_width=True, type="primary", ): @@ -194,3 +219,11 @@ def knowledge_base_page(api: ApiRequest): print(d) empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") empty.write("重建完毕") + + if cols[2].button( + "删除知识库", + use_container_width=True, + ): + ret = api.delete_knowledge_base(kb) + st.toast(ret["detail"][0]["msg"]) + st.experimental_rerun()