恢复 删除知识库选项

This commit is contained in:
zR 2024-01-22 12:46:15 +08:00
parent 80c26e4a24
commit b6d2bc71ce

View File

@ -7,15 +7,12 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple from typing import Literal, Dict, Tuple
from configs import (kbs_config, from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE, EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models, list_online_embed_models from server.utils import list_embed_models, list_online_embed_models
import os import os
import time import time
# SENTENCE_SIZE = 100
cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""") cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""")
@ -32,7 +29,7 @@ def config_aggrid(
gb.configure_selection( gb.configure_selection(
selection_mode=selection_mode, selection_mode=selection_mode,
use_checkbox=use_checkbox, use_checkbox=use_checkbox,
# pre_selected_rows=st.session_state.get("selected_rows", [0]), pre_selected_rows=st.session_state.get("selected_rows", [0]),
) )
gb.configure_pagination( gb.configure_pagination(
enabled=True, enabled=True,
@ -59,7 +56,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
try: try:
kb_list = {x["kb_name"]: x for x in get_kb_details()} kb_list = {x["kb_name"]: x for x in get_kb_details()}
except Exception as e: except Exception as e:
st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。") st.error(
"获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
st.stop() st.stop()
kb_names = list(kb_list.keys()) kb_names = list(kb_list.keys())
@ -150,7 +148,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
[i for ls in LOADER_DICT.values() for i in ls], [i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True, accept_multiple_files=True,
) )
kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None, kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
key=None,
help=None, on_change=None, args=None, kwargs=None) help=None, on_change=None, args=None, kwargs=None)
if kb_info != st.session_state["selected_kb_info"]: if kb_info != st.session_state["selected_kb_info"]:
@ -200,8 +199,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
doc_details = doc_details[[ doc_details = doc_details[[
"No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db", "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
]] ]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×") doc_details["in_folder"] = doc_details["in_folder"].replace(True, "").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×") doc_details["in_db"] = doc_details["in_db"].replace(True, "").replace(False, "×")
gb = config_aggrid( gb = config_aggrid(
doc_details, doc_details,
{ {
@ -252,7 +251,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.write() st.write()
# 将文件分词并加载到向量库中 # 将文件分词并加载到向量库中
if cols[1].button( if cols[1].button(
"重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库", "重新添加至向量库" if selected_rows and (
pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
disabled=not file_exists(kb, selected_rows)[0], disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True, use_container_width=True,
): ):
@ -285,39 +285,39 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.divider() st.divider()
# cols = st.columns(3) cols = st.columns(3)
# if cols[0].button( if cols[0].button(
# "依据源文件重建向量库", "依据源文件重建向量库",
# # help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。", help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
# use_container_width=True, use_container_width=True,
# type="primary", type="primary",
# ): ):
# with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
# empty = st.empty() empty = st.empty()
# empty.progress(0.0, "") empty.progress(0.0, "")
# for d in api.recreate_vector_store(kb, for d in api.recreate_vector_store(kb,
# chunk_size=chunk_size, chunk_size=chunk_size,
# chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
# zh_title_enhance=zh_title_enhance): zh_title_enhance=zh_title_enhance):
# if msg := check_error_msg(d): if msg := check_error_msg(d):
# st.toast(msg) st.toast(msg)
# else: else:
# empty.progress(d["finished"] / d["total"], d["msg"]) empty.progress(d["finished"] / d["total"], d["msg"])
# st.rerun() st.rerun()
# if cols[2].button( if cols[2].button(
# "删除知识库", "删除知识库",
# use_container_width=True, use_container_width=True,
# ): ):
# ret = api.delete_knowledge_base(kb) ret = api.delete_knowledge_base(kb)
# st.toast(ret.get("msg", " ")) st.toast(ret.get("msg", " "))
# time.sleep(1) time.sleep(1)
# st.rerun() st.rerun()
# with st.sidebar: with st.sidebar:
# keyword = st.text_input("查询关键字") keyword = st.text_input("查询关键字")
# top_k = st.slider("匹配条数", 1, 100, 3) top_k = st.slider("匹配条数", 1, 100, 3)
st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。") st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
docs = [] docs = []
@ -325,11 +325,12 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
if selected_rows: if selected_rows:
file_name = selected_rows[0]["file_name"] file_name = selected_rows[0]["file_name"]
docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name) docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"), data = [
"type": x["type"], {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
"metadata": json.dumps(x["metadata"], ensure_ascii=False), "type": x["type"],
"to_del": "", "metadata": json.dumps(x["metadata"], ensure_ascii=False),
} for i, x in enumerate(docs)] "to_del": "",
} for i, x in enumerate(docs)]
df = pd.DataFrame(data) df = pd.DataFrame(data)
gb = GridOptionsBuilder.from_dataframe(df) gb = GridOptionsBuilder.from_dataframe(df)
@ -343,22 +344,24 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
edit_docs = AgGrid(df, gb.build()) edit_docs = AgGrid(df, gb.build())
if st.button("保存更改"): if st.button("保存更改"):
# origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs} origin_docs = {
x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
docs}
changed_docs = [] changed_docs = []
for index, row in edit_docs.data.iterrows(): for index, row in edit_docs.data.iterrows():
# origin_doc = origin_docs[row["id"]] origin_doc = origin_docs[row["id"]]
# if row["page_content"] != origin_doc["page_content"]: if row["page_content"] != origin_doc["page_content"]:
if row["to_del"] not in ["Y", "y", 1]: if row["to_del"] not in ["Y", "y", 1]:
changed_docs.append({ changed_docs.append({
"page_content": row["page_content"], "page_content": row["page_content"],
"type": row["type"], "type": row["type"],
"metadata": json.loads(row["metadata"]), "metadata": json.loads(row["metadata"]),
}) })
if changed_docs: if changed_docs:
if api.update_kb_docs(knowledge_base_name=selected_kb, if api.update_kb_docs(knowledge_base_name=selected_kb,
file_names=[file_name], file_names=[file_name],
docs={file_name: changed_docs}): docs={file_name: changed_docs}):
st.toast("更新文档成功") st.toast("更新文档成功")
else: else:
st.toast("更新文档失败") st.toast("更新文档失败")