mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-08 07:53:29 +08:00
Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever.
This commit is contained in:
parent
b110fcd01b
commit
b56283eb01
@ -1,13 +1,14 @@
|
|||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from chatchat.server.utils import get_tool_config
|
from chatchat.server.utils import get_tool_config
|
||||||
from chatchat.server.pydantic_v1 import Field
|
from chatchat.server.pydantic_v1 import Field
|
||||||
from .tools_registry import regist_tool, BaseToolOutput
|
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
|
||||||
from chatchat.server.knowledge_base.kb_api import list_kbs
|
from chatchat.server.knowledge_base.kb_api import list_kbs
|
||||||
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
||||||
from chatchat.configs import KB_INFO
|
from chatchat.configs import KB_INFO
|
||||||
|
|
||||||
|
|
||||||
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
|
template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
|
||||||
|
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
|
||||||
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
||||||
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
||||||
|
|
||||||
@ -49,7 +50,7 @@ def search_local_knowledgebase(
|
|||||||
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
||||||
query: str = Field(description="Query for Knowledge Search"),
|
query: str = Field(description="Query for Knowledge Search"),
|
||||||
):
|
):
|
||||||
''''''
|
""""""
|
||||||
tool_config = get_tool_config("search_local_knowledgebase")
|
tool_config = get_tool_config("search_local_knowledgebase")
|
||||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||||
return KBToolOutput(ret, database=database)
|
return KBToolOutput(ret, database=database)
|
||||||
|
|||||||
@ -63,10 +63,10 @@ def upload_temp_docs(
|
|||||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
'''
|
"""
|
||||||
将文件保存到临时目录,并进行向量化。
|
将文件保存到临时目录,并进行向量化。
|
||||||
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
||||||
'''
|
"""
|
||||||
if prev_id is not None:
|
if prev_id is not None:
|
||||||
memo_faiss_pool.pop(prev_id)
|
memo_faiss_pool.pop(prev_id)
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
|||||||
docs = [x[0] for x in docs]
|
docs = [x[0] for x in docs]
|
||||||
|
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
|
if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板
|
||||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||||
else:
|
else:
|
||||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
from chatchat.server.document_loaders.ocr import get_ocr
|
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||||
|
|
||||||
|
|
||||||
class RapidOCRLoader(UnstructuredFileLoader):
|
class RapidOCRLoader(UnstructuredFileLoader):
|
||||||
@ -4,7 +4,7 @@ import cv2
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from chatchat.configs import PDF_OCR_THRESHOLD
|
from chatchat.configs import PDF_OCR_THRESHOLD
|
||||||
from chatchat.server.document_loaders.ocr import get_ocr
|
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for retriever services.

    ``__init__`` forwards every keyword argument to ``do_init``, so concrete
    subclasses configure themselves in ``do_init`` rather than overriding
    ``__init__``.
    """

    def __init__(self, **kwargs):
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialise the instance from the keyword arguments given to ``__init__``."""

    # ``staticmethod`` must be the outer decorator so that the abstract flag
    # set by ``abstractmethod`` is still visible to ABCMeta. Declaring the
    # method static here matches how the subclasses implement it: a factory
    # that is called on the class and takes no ``self``.
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        # Quoted so the annotation is not evaluated when the class is created.
        vectorstore: "VectorStore",
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # integer arguments for type checkers.
        score_threshold: float,
    ):
        """Build a retriever service on top of ``vectorstore``.

        Args:
            vectorstore: the vector store to retrieve from.
            top_k: maximum number of documents to retrieve.
            score_threshold: score threshold applied to the similarity search.
        """

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return the documents relevant to ``query``."""
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from langchain_community.retrievers import BM25Retriever
|
||||||
|
from langchain.retrievers import EnsembleRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever service that combines BM25 keyword search with vector similarity search."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        Args:
            retriever: the retriever that ``get_relevant_documents`` delegates to.
            top_k: maximum number of documents ``get_relevant_documents`` returns.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # integer arguments for type checkers.
        score_threshold: float,
    ):
        """Build an ensemble of a BM25 retriever and a similarity-threshold retriever.

        Args:
            vectorstore: store whose documents are indexed by both retrievers.
            top_k: number of documents each retriever fetches, and the cap on
                the merged result returned by ``get_relevant_documents``.
            score_threshold: threshold passed to the
                ``similarity_score_threshold`` search of the vector retriever.

        Returns:
            An ``EnsembleRetrieverService`` weighting both retrievers equally.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # TODO: switch to an implementation that does not depend on torch
        from cutword.cutword import Cutter

        cutter = Cutter()
        # NOTE(review): reads the store's private ``docstore._dict`` mapping, so
        # this presumably only works for stores with an in-memory docstore
        # (e.g. FAISS) -- confirm before using it with other backends.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        # Forward top_k: without it the service falls back to the default of 5
        # and get_relevant_documents truncates to 5 regardless of the request.
        return EnsembleRetrieverService(retriever=ensemble_retriever, top_k=top_k)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents from the wrapped retriever for ``query``."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class VectorstoreRetrieverService(BaseRetrieverService):
    """Retriever service backed by a single vector store similarity search."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        Args:
            retriever: the retriever that ``get_relevant_documents`` delegates to.
            top_k: maximum number of documents ``get_relevant_documents`` returns.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # integer arguments for type checkers.
        score_threshold: float,
    ):
        """Wrap ``vectorstore`` in a similarity-score-threshold retriever.

        Args:
            vectorstore: store to search.
            top_k: number of documents the retriever fetches, and the cap on
                the result returned by ``get_relevant_documents``.
            score_threshold: threshold passed to the
                ``similarity_score_threshold`` search.

        Returns:
            A ``VectorstoreRetrieverService`` around the created retriever.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # Forward top_k: without it the service falls back to the default of 5
        # and get_relevant_documents truncates to 5 regardless of the request.
        return VectorstoreRetrieverService(retriever=retriever, top_k=top_k)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents from the wrapped retriever for ``query``."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||||
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers import (
|
||||||
|
BaseRetrieverService,
|
||||||
|
VectorstoreRetrieverService,
|
||||||
|
EnsembleRetrieverService,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Registry mapping a retriever type name to its service class.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> "type[BaseRetrieverService]":
    """Look up a retriever service class by name.

    The class itself is returned, not an instance; callers build a service
    through its ``from_vectorstore`` factory.

    Args:
        type: key into ``Retrivals`` ("vectorstore" or "ensemble").

    Returns:
        The ``BaseRetrieverService`` subclass registered under ``type``.

    Raises:
        KeyError: if ``type`` is not a registered retriever name.
    """
    try:
        return Retrivals[type]
    except KeyError:
        # Same exception type as the plain lookup, but name the valid choices.
        raise KeyError(
            f"unknown retriever type {type!r}; expected one of {sorted(Retrivals)}"
        ) from None
|
||||||
@ -37,7 +37,8 @@ def search_docs(
|
|||||||
if kb is not None:
|
if kb is not None:
|
||||||
if query:
|
if query:
|
||||||
docs = kb.search_docs(query, top_k, score_threshold)
|
docs = kb.search_docs(query, top_k, score_threshold)
|
||||||
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||||
|
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
|
||||||
elif file_name or metadata:
|
elif file_name or metadata:
|
||||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||||
for d in data:
|
for d in data:
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
||||||
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
||||||
Tuple[Document, float]]:
|
Tuple[Document, float]]:
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.collection,
|
||||||
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
|
top_k=top_k,
|
||||||
return _results_to_docs_and_scores(query_result)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
doc_infos = []
|
doc_infos = []
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
|
|||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
from elasticsearch import Elasticsearch, BadRequestError
|
from elasticsearch import Elasticsearch, BadRequestError
|
||||||
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -107,8 +108,12 @@ class ESKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query:str, top_k: int, score_threshold: float):
|
def do_search(self, query:str, top_k: int, score_threshold: float):
|
||||||
# 文本相似性检索
|
# 文本相似性检索
|
||||||
docs = self.db.similarity_search_with_score(query=query,
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
k=top_k)
|
self.db,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
|
|||||||
@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||||
from chatchat.server.utils import get_Embeddings
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
@ -62,10 +62,13 @@ class FaissKBService(KBService):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
|
||||||
embeddings = embed_func.embed_query(query)
|
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
retriever = get_Retriever("ensemble").from_vectorstore(
|
||||||
|
vs,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
||||||
score_threshold_process
|
score_threshold_process
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class MilvusKBService(KBService):
|
class MilvusKBService(KBService):
|
||||||
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
self._load_milvus()
|
self._load_milvus()
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
# embed_func = get_Embeddings(self.embed_model)
|
||||||
embeddings = embed_func.embed_query(query)
|
# embeddings = embed_func.embed_query(query)
|
||||||
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
|
self.milvus,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import shutil
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class PGKBService(KBService):
|
class PGKBService(KBService):
|
||||||
@ -60,10 +61,13 @@ class PGKBService(KBService):
|
|||||||
shutil.rmtree(self.kb_path)
|
shutil.rmtree(self.kb_path)
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.pg_vector,
|
||||||
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
|
top_k=top_k,
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
ids = self.pg_vector.add_documents(docs)
|
ids = self.pg_vector.add_documents(docs)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from typing import List, Dict, Optional
|
from typing import List, Dict
|
||||||
from langchain.embeddings.base import Embeddings
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.vectorstores import Zilliz
|
from langchain.vectorstores import Zilliz
|
||||||
from chatchat.configs import kbs_config
|
from chatchat.configs import kbs_config
|
||||||
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
|||||||
score_threshold_process
|
score_threshold_process
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class ZillizKBService(KBService):
|
class ZillizKBService(KBService):
|
||||||
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
self._load_zilliz()
|
self._load_zilliz()
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.zilliz,
|
||||||
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
|
top_k=top_k,
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from chatchat.configs import (
|
|||||||
TEXT_SPLITTER_NAME,
|
TEXT_SPLITTER_NAME,
|
||||||
)
|
)
|
||||||
import importlib
|
import importlib
|
||||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||||
import langchain_community.document_loaders
|
import langchain_community.document_loaders
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||||
|
|||||||
@ -24,28 +24,28 @@ chat_box = ChatBox(
|
|||||||
|
|
||||||
|
|
||||||
def save_session():
|
def save_session():
|
||||||
'''save session state to chat context'''
|
"""save session state to chat context"""
|
||||||
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||||
|
|
||||||
|
|
||||||
def restore_session():
|
def restore_session():
|
||||||
'''restore sesstion state from chat context'''
|
"""restore sesstion state from chat context"""
|
||||||
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||||
|
|
||||||
|
|
||||||
def rerun():
|
def rerun():
|
||||||
'''
|
"""
|
||||||
save chat context before rerun
|
save chat context before rerun
|
||||||
'''
|
"""
|
||||||
save_session()
|
save_session()
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||||
'''
|
"""
|
||||||
返回消息历史。
|
返回消息历史。
|
||||||
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def filter(msg):
|
def filter(msg):
|
||||||
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
||||||
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
|||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||||
'''
|
"""
|
||||||
将文件上传到临时目录,用于文件对话
|
将文件上传到临时目录,用于文件对话
|
||||||
返回临时向量库ID
|
返回临时向量库ID
|
||||||
'''
|
"""
|
||||||
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||||
|
|
||||||
|
|
||||||
@ -157,11 +157,13 @@ def dialogue_page(
|
|||||||
tools = list_tools(api)
|
tools = list_tools(api)
|
||||||
tool_names = ["None"] + list(tools)
|
tool_names = ["None"] + list(tools)
|
||||||
if use_agent:
|
if use_agent:
|
||||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
|
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||||
|
# check_all=True, key="selected_tools")
|
||||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
||||||
key="selected_tools")
|
key="selected_tools")
|
||||||
else:
|
else:
|
||||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
|
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||||
|
# key="selected_tool")
|
||||||
selected_tool = st.selectbox("选择工具", tool_names,
|
selected_tool = st.selectbox("选择工具", tool_names,
|
||||||
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
||||||
key="selected_tool")
|
key="selected_tool")
|
||||||
@ -338,7 +340,7 @@ def dialogue_page(
|
|||||||
elif d.status == AgentStatus.agent_finish:
|
elif d.status == AgentStatus.agent_finish:
|
||||||
text = d.choices[0].delta.content or ""
|
text = d.choices[0].delta.content or ""
|
||||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||||
elif d.status == None: # not agent chat
|
elif d.status is None: # not agent chat
|
||||||
if getattr(d, "is_ref", False):
|
if getattr(d, "is_ref", False):
|
||||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
||||||
title="参考资料"))
|
title="参考资料"))
|
||||||
|
|||||||
@ -29,6 +29,8 @@ unstructured = "~0.11.0"
|
|||||||
python-magic-bin = {version = "*", platform = "win32"}
|
python-magic-bin = {version = "*", platform = "win32"}
|
||||||
SQLAlchemy = "~2.0.25"
|
SQLAlchemy = "~2.0.25"
|
||||||
faiss-cpu = "~1.7.4"
|
faiss-cpu = "~1.7.4"
|
||||||
|
cutword = "0.1.0"
|
||||||
|
rank_bm25 = "0.2.2"
|
||||||
# accelerate = "~0.24.1"
|
# accelerate = "~0.24.1"
|
||||||
# spacy = "~3.7.2"
|
# spacy = "~3.7.2"
|
||||||
PyMuPDF = "~1.23.16"
|
PyMuPDF = "~1.23.16"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user