Merge branch 'dev' into dev_config_init

This commit is contained in:
glide-the 2024-06-10 22:00:22 +08:00
commit 72b1cab89a
37 changed files with 211 additions and 65 deletions

View File

@ -1,11 +1,13 @@
import os import os
from pathlib import Path from pathlib import Path
# chatchat 项目根目录 import sys
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) sys.path.append(str(Path(__file__).parent))
from _basic_config import config_workspace
# 用户数据根目录 # 用户数据根目录
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") DATA_PATH = config_workspace.get_config().DATA_PATH
# 默认使用的知识库 # 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples" DEFAULT_KNOWLEDGE_BASE = "samples"

View File

@ -149,7 +149,7 @@ TOOL_CONFIG = {
"search_local_knowledgebase": { "search_local_knowledgebase": {
"use": False, "use": False,
"top_k": 3, "top_k": 3,
"score_threshold": 1, "score_threshold": 1.0,
"conclude_prompt": { "conclude_prompt": {
"with_result": "with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"' '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'

View File

@ -1,13 +1,14 @@
from urllib.parse import urlencode from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO from chatchat.configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on "
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples") template_knowledge = template.format(KB_info=KB_info_str, key="samples")
@ -49,7 +50,7 @@ def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
query: str = Field(description="Query for Knowledge Search"), query: str = Field(description="Query for Knowledge Search"),
): ):
'''''' """"""
tool_config = get_tool_config("search_local_knowledgebase") tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config) ret = search_knowledgebase(query=query, database=database, config=tool_config)
return KBToolOutput(ret, database=database) return KBToolOutput(ret, database=database)

View File

@ -63,10 +63,10 @@ def upload_temp_docs(
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse: ) -> BaseResponse:
''' """
将文件保存到临时目录并进行向量化 将文件保存到临时目录并进行向量化
返回临时目录名称作为ID同时也是临时向量库的ID 返回临时目录名称作为ID同时也是临时向量库的ID
''' """
if prev_id is not None: if prev_id is not None:
memo_faiss_pool.pop(prev_id) memo_faiss_pool.pop(prev_id)
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
docs = [x[0] for x in docs] docs = [x[0] for x in docs]
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板 if len(docs) == 0: # 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty") prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else: else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader): class RapidOCRLoader(UnstructuredFileLoader):

View File

@ -4,7 +4,7 @@ import cv2
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from chatchat.configs import PDF_OCR_THRESHOLD from chatchat.configs import PDF_OCR_THRESHOLD
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
import tqdm import tqdm

View File

@ -0,0 +1,3 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService

View File

@ -0,0 +1,24 @@
from langchain.vectorstores import VectorStore
from abc import ABCMeta, abstractmethod
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract base for retriever services wrapping a langchain VectorStore.

    Subclasses implement:
      * ``do_init`` -- store construction kwargs on the instance,
      * ``from_vectorstore`` -- factory building a service from a VectorStore,
      * ``get_relevant_documents`` -- run a query against the retriever.
    """

    def __init__(self, **kwargs):
        # All construction is delegated to the subclass hook so that
        # subclasses only need to override do_init, not __init__.
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialize subclass state from the construction kwargs."""
        pass

    # Declared @staticmethod to match every concrete override (the original
    # abstract signature had no `self` but was not marked static).
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # Fix: the original `int or float` annotation evaluates to plain
        # `int` at runtime; `float` accepts ints and states the intent.
        score_threshold: float,
    ):
        """Build a retriever service instance from *vectorstore*."""
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return the documents relevant to *query*."""
        pass

View File

@ -0,0 +1,48 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
class EnsembleRetrieverService(BaseRetrieverService):
    """Hybrid retriever combining BM25 keyword search with vector search.

    Results from the two retrievers are merged with equal (0.5/0.5) weights
    and the final list is truncated to ``top_k``.
    """

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the prepared ensemble retriever and the result cap."""
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # Fix: `int or float` evaluates to plain `int`; use `float`.
        score_threshold: float,
    ):
        """Build an ensemble (BM25 + similarity) retriever from *vectorstore*.

        NOTE: ``score_threshold`` only applies to the vector-store side;
        BM25 results are not filtered by score.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )

        # TODO: replace with a tokenizer implementation that does not
        # depend on torch.
        import jieba

        # Reaches into the FAISS docstore internals to enumerate all
        # documents for BM25 indexing -- assumes a FAISS-backed store.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=jieba.lcut_for_search,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever)

    def get_relevant_documents(self, query: str):
        """Query the ensemble retriever, capped at ``top_k`` documents."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,33 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
class VectorstoreRetrieverService(BaseRetrieverService):
    """Plain similarity-search retriever over a langchain VectorStore."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the prepared retriever and the result cap."""
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # Fix: `int or float` evaluates to plain `int`; use `float`.
        score_threshold: float,
    ):
        """Build a similarity-score-threshold retriever from *vectorstore*."""
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        return VectorstoreRetrieverService(retriever=retriever)

    def get_relevant_documents(self, query: str):
        """Query the retriever, capped at ``top_k`` documents."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,13 @@
from chatchat.server.file_rag.retrievers import (
BaseRetrieverService,
VectorstoreRetrieverService,
EnsembleRetrieverService,
)
# Registry of retriever-service classes by short name.
# (Name "Retrivals" is a typo for "Retrievals" but is kept -- it is a
# module-level name other code may already import.)
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
    """Look up a retriever-service class by name.

    :param type: registry key, one of "vectorstore" or "ensemble".
        (Parameter name shadows the builtin ``type``; kept so callers
        using ``type=`` keyword arguments keep working.)
    :returns: the service *class* itself (despite the annotation, this is
        not an instance) -- callers invoke ``.from_vectorstore(...)`` on it.
    :raises KeyError: with an explicit message listing the valid keys,
        instead of the original bare KeyError.
    """
    try:
        return Retrivals[type]
    except KeyError:
        raise KeyError(
            f"unknown retriever type {type!r}; expected one of {sorted(Retrivals)}"
        ) from None

View File

@ -110,7 +110,8 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")): if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = get_Embeddings(embed_model=embed_model) embeddings = get_Embeddings(embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,
allow_dangerous_deserialization=True)
elif create: elif create:
# create an empty vector store # create an empty vector store
if not os.path.exists(vs_path): if not os.path.exists(vs_path):

View File

@ -28,7 +28,7 @@ def search_docs(
description="知识库匹配相关度阈值取值范围在0-1之间" description="知识库匹配相关度阈值取值范围在0-1之间"
"SCORE越小相关度越高" "SCORE越小相关度越高"
"取到1相当于不筛选建议设置在0.5左右", "取到1相当于不筛选建议设置在0.5左右",
ge=0, le=1), ge=0.0, le=1.0),
file_name: str = Body("", description="文件名称,支持 sql 通配符"), file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[Dict]: ) -> List[Dict]:
@ -37,7 +37,8 @@ def search_docs(
if kb is not None: if kb is not None:
if query: if query:
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
elif file_name or metadata: elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata) data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data: for d in data:

View File

@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
def _get_result_to_documents(get_result: GetResult) -> List[Document]: def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
Tuple[Document, float]]: Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.collection,
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) top_k=top_k,
return _results_to_docs_and_scores(query_result) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
doc_infos = [] doc_infos = []

View File

@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from elasticsearch import Elasticsearch, BadRequestError from elasticsearch import Elasticsearch, BadRequestError
from chatchat.configs import kbs_config, KB_ROOT_PATH from chatchat.configs import kbs_config, KB_ROOT_PATH
from chatchat.server.file_rag.utils import get_Retriever
import logging import logging
@ -107,8 +108,12 @@ class ESKBService(KBService):
def do_search(self, query:str, top_k: int, score_threshold: float): def do_search(self, query:str, top_k: int, score_threshold: float):
# 文本相似性检索 # 文本相似性检索
docs = self.db.similarity_search_with_score(query=query, retriever = get_Retriever("vectorstore").from_vectorstore(
k=top_k) self.db,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:

View File

@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Tuple
from chatchat.server.file_rag.utils import get_Retriever
class FaissKBService(KBService): class FaissKBService(KBService):
@ -62,10 +62,13 @@ class FaissKBService(KBService):
top_k: int, top_k: int,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs: with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) retriever = get_Retriever("ensemble").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def do_add_doc(self, def do_add_doc(self,

View File

@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.file_rag.utils import get_Retriever
class MilvusKBService(KBService): class MilvusKBService(KBService):
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus() self._load_milvus()
embed_func = get_Embeddings(self.embed_model) # embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query) # embeddings = embed_func.embed_query(query)
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs) retriever = get_Retriever("vectorstore").from_vectorstore(
self.milvus,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -15,6 +15,7 @@ import shutil
import sqlalchemy import sqlalchemy
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from chatchat.server.file_rag.utils import get_Retriever
class PGKBService(KBService): class PGKBService(KBService):
@ -60,10 +61,13 @@ class PGKBService(KBService):
shutil.rmtree(self.kb_path) shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.pg_vector,
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs) ids = self.pg_vector.add_documents(docs)

View File

@ -1,5 +1,4 @@
from typing import List, Dict, Optional from typing import List, Dict
from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import Zilliz from langchain.vectorstores import Zilliz
from chatchat.configs import kbs_config from chatchat.configs import kbs_config
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
class ZillizKBService(KBService): class ZillizKBService(KBService):
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz() self._load_zilliz()
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.zilliz,
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -10,7 +10,7 @@ from chatchat.configs import (
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
) )
import importlib import importlib
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders import langchain_community.document_loaders
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter

View File

@ -24,28 +24,28 @@ chat_box = ChatBox(
def save_session(): def save_session():
'''save session state to chat context''' """save session state to chat context"""
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def restore_session(): def restore_session():
'''restore sesstion state from chat context''' """restore sesstion state from chat context"""
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def rerun(): def rerun():
''' """
save chat context before rerun save chat context before rerun
''' """
save_session() save_session()
st.rerun() st.rerun()
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' """
返回消息历史 返回消息历史
content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要 content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要
''' """
def filter(msg): def filter(msg):
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
@st.cache_data @st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str: def upload_temp_docs(files, _api: ApiRequest) -> str:
''' """
将文件上传到临时目录用于文件对话 将文件上传到临时目录用于文件对话
返回临时向量库ID 返回临时向量库ID
''' """
return _api.upload_temp_docs(files).get("data", {}).get("id") return _api.upload_temp_docs(files).get("data", {}).get("id")
@ -157,11 +157,13 @@ def dialogue_page(
tools = list_tools(api) tools = list_tools(api)
tool_names = ["None"] + list(tools) tool_names = ["None"] + list(tools)
if use_agent: if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
key="selected_tools") key="selected_tools")
else: else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names, selected_tool = st.selectbox("选择工具", tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"], format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool") key="selected_tool")
@ -338,7 +340,7 @@ def dialogue_page(
elif d.status == AgentStatus.agent_finish: elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or "" text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n")) chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat elif d.status is None: # not agent chat
if getattr(d, "is_ref", False): if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
title="参考资料")) title="参考资料"))

View File

@ -1,9 +1,6 @@
[virtualenvs] [virtualenvs]
in-project = true in-project = true
[installer]
modern-installation = false
[plugins] [plugins]
[plugins.pypi_mirror] [plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple" url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -29,6 +29,9 @@ unstructured = "~0.11.0"
python-magic-bin = {version = "*", platform = "win32"} python-magic-bin = {version = "*", platform = "win32"}
SQLAlchemy = "~2.0.25" SQLAlchemy = "~2.0.25"
faiss-cpu = "~1.7.4" faiss-cpu = "~1.7.4"
#cutword = "0.1.0"
jieba = "0.42.1"
rank_bm25 = "0.2.2"
# accelerate = "~0.24.1" # accelerate = "~0.24.1"
# spacy = "~3.7.2" # spacy = "~3.7.2"
PyMuPDF = "~1.23.16" PyMuPDF = "~1.23.16"
@ -264,4 +267,4 @@ dotenv = "dotenv:plugin"
[[tool.poetry.source]] [[tool.poetry.source]]
name = "tsinghua" name = "tsinghua"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/" url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default" priority = "primary"

View File

@ -1,9 +1,6 @@
[virtualenvs] [virtualenvs]
in-project = true in-project = true
[installer]
modern-installation = false
[plugins] [plugins]
[plugins.pypi_mirror] [plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple" url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -1,9 +1,5 @@
[virtualenvs] [virtualenvs]
in-project = true in-project = true
[installer]
modern-installation = false
[plugins]
[plugins.pypi_mirror] [plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple" url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -137,7 +137,7 @@ model_format = None
model_quant = None model_quant = None
if model_type == "LLM": if model_type == "LLM":
cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg) cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
cur_spec = None cur_spec = None
model_formats = [] model_formats = []
for spec in cur_reg["model_specs"]: for spec in cur_reg["model_specs"]: