diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 31fc1512..76a97377 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -7,15 +7,12 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
 from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
-                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+                    EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+                    CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
 from server.utils import list_embed_models, list_online_embed_models
 import os
 import time
-
-# SENTENCE_SIZE = 100
-
 cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
 
 
@@ -32,7 +29,7 @@ def config_aggrid(
     gb.configure_selection(
         selection_mode=selection_mode,
         use_checkbox=use_checkbox,
-        # pre_selected_rows=st.session_state.get("selected_rows", [0]),
+        pre_selected_rows=st.session_state.get("selected_rows", [0]),
     )
     gb.configure_pagination(
         enabled=True,
@@ -59,7 +56,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
     try:
         kb_list = {x["kb_name"]: x for x in get_kb_details()}
     except Exception as e:
-        st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
+        st.error(
+            "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
         st.stop()
 
     kb_names = list(kb_list.keys())
@@ -150,7 +148,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
                                  [i for ls in LOADER_DICT.values() for i in ls],
                                  accept_multiple_files=True,
                                  )
-        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None,
+        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
+                               key=None,
                                help=None, on_change=None, args=None, kwargs=None)
 
         if kb_info != st.session_state["selected_kb_info"]:
@@ -200,8 +199,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             doc_details = doc_details[[
                 "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
             ]]
-            # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
-            # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
+            doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
+            doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
             gb = config_aggrid(
                 doc_details,
                 {
@@ -252,7 +251,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             st.write()
             # 将文件分词并加载到向量库中
            if cols[1].button(
-                    "重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
+                    "重新添加至向量库" if selected_rows and (
+                        pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
                    disabled=not file_exists(kb, selected_rows)[0],
                    use_container_width=True,
            ):
@@ -285,39 +285,39 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
 
        st.divider()
 
-        # cols = st.columns(3)
+        cols = st.columns(3)
 
-        # if cols[0].button(
-        #         "依据源文件重建向量库",
-        #         # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
-        #         use_container_width=True,
-        #         type="primary",
-        # ):
-        #     with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
-        #         empty = st.empty()
-        #         empty.progress(0.0, "")
-        #         for d in api.recreate_vector_store(kb,
-        #                                            chunk_size=chunk_size,
-        #                                            chunk_overlap=chunk_overlap,
-        #                                            zh_title_enhance=zh_title_enhance):
-        #             if msg := check_error_msg(d):
-        #                 st.toast(msg)
-        #             else:
-        #                 empty.progress(d["finished"] / d["total"], d["msg"])
-        #         st.rerun()
+        if cols[0].button(
+                "依据源文件重建向量库",
+                help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
+                use_container_width=True,
+                type="primary",
+        ):
+            with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
+                empty = st.empty()
+                empty.progress(0.0, "")
+                for d in api.recreate_vector_store(kb,
+                                                   chunk_size=chunk_size,
+                                                   chunk_overlap=chunk_overlap,
+                                                   zh_title_enhance=zh_title_enhance):
+                    if msg := check_error_msg(d):
+                        st.toast(msg)
+                    else:
+                        empty.progress(d["finished"] / d["total"], d["msg"])
+                st.rerun()
 
-        # if cols[2].button(
-        #         "删除知识库",
-        #         use_container_width=True,
-        # ):
-        #     ret = api.delete_knowledge_base(kb)
-        #     st.toast(ret.get("msg", " "))
-        #     time.sleep(1)
-        #     st.rerun()
+        if cols[2].button(
+                "删除知识库",
+                use_container_width=True,
+        ):
+            ret = api.delete_knowledge_base(kb)
+            st.toast(ret.get("msg", " "))
+            time.sleep(1)
+            st.rerun()
 
-        # with st.sidebar:
-        #     keyword = st.text_input("查询关键字")
-        #     top_k = st.slider("匹配条数", 1, 100, 3)
+        with st.sidebar:
+            keyword = st.text_input("查询关键字")
+            top_k = st.slider("匹配条数", 1, 100, 3)
 
        st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
        docs = []
@@ -325,11 +325,12 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
        if selected_rows:
            file_name = selected_rows[0]["file_name"]
            docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
-            data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
-                     "type": x["type"],
-                     "metadata": json.dumps(x["metadata"], ensure_ascii=False),
-                     "to_del": "",
-                     } for i, x in enumerate(docs)]
+            data = [
+                {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
+                 "type": x["type"],
+                 "metadata": json.dumps(x["metadata"], ensure_ascii=False),
+                 "to_del": "",
+                 } for i, x in enumerate(docs)]
            df = pd.DataFrame(data)
 
            gb = GridOptionsBuilder.from_dataframe(df)
@@ -343,22 +344,24 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
            edit_docs = AgGrid(df, gb.build())
 
            if st.button("保存更改"):
-                # origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs}
+                origin_docs = {
+                    x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
+                    docs}
                changed_docs = []
                for index, row in edit_docs.data.iterrows():
-                    # origin_doc = origin_docs[row["id"]]
-                    # if row["page_content"] != origin_doc["page_content"]:
-                    if row["to_del"] not in ["Y", "y", 1]:
-                        changed_docs.append({
-                            "page_content": row["page_content"],
-                            "type": row["type"],
-                            "metadata": json.loads(row["metadata"]),
-                        })
+                    origin_doc = origin_docs[row["id"]]
+                    if row["page_content"] != origin_doc["page_content"]:
+                        if row["to_del"] not in ["Y", "y", 1]:
+                            changed_docs.append({
+                                "page_content": row["page_content"],
+                                "type": row["type"],
+                                "metadata": json.loads(row["metadata"]),
+                            })
 
                if changed_docs:
                    if api.update_kb_docs(knowledge_base_name=selected_kb,
-                                          file_names=[file_name],
-                                          docs={file_name: changed_docs}):
+                                           file_names=[file_name],
+                                           docs={file_name: changed_docs}):
                        st.toast("更新文档成功")
                    else:
                        st.toast("更新文档失败")