New features:

- add_docs/delete_docs/update_docs in knowledge base management all support batch operations, using multithreading for better throughput
- The API's knowledge base rebuild endpoint (recreate_vector_store) supports multithreading
- add_docs accepts a parameter (`to_vector_store`) controlling whether vectorization runs after files are uploaded
- add_docs/update_docs accept custom docs passed as JSON (see the client sketch after this list); distinguishing full vs. supplementary custom docs may come later
- The download_doc endpoint adds a `preview` parameter to choose between downloading and in-browser preview
- kb_service adds a `save_vector_store` method for persisting the vector store (FAISS only; a no-op for other stores)
- The document_loader & text_splitter logic is extracted out of KnowledgeFile, in preparation for vectorizing in-memory files
- KnowledgeFile caches docs & splited_docs, making it easy to customize intermediate results

Other changes:
- Some error output switched from print to logger.error
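The batch endpoints take several files plus an optional `docs` mapping in a single request. A minimal client sketch, patterned on the updated tests further down; it assumes api.py is serving at its default local address and that the listed files exist locally:

import json
import requests

base_url = "http://127.0.0.1:7861"  # assumed default api.py address

# Upload two files at once; `docs` supplies custom Documents for one of them.
files = [("files", (name, open(name, "rb"))) for name in ["README.MD", "test.txt"]]
data = {
    "knowledge_base_name": "samples",
    "override": True,
    "to_vector_store": True,  # vectorize right after saving
    "docs": json.dumps({"test.txt": [{"page_content": "custom doc", "metadata": {}}]}),
}
resp = requests.post(base_url + "/knowledge_base/upload_docs", data=data, files=files)
print(resp.json()["data"]["failed_files"])  # per-file error messages; empty on success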
liunux4odoo 2023-09-08 08:55:12 +08:00
parent 93b133f9ac
commit 661a0e9d72
9 changed files with 818 additions and 223 deletions

View File

@@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
@@ -98,23 +98,23 @@ def create_app():
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="上传文件到知识库"
)(upload_doc)
summary="上传文件到知识库,并/或进行向量化"
)(upload_docs)
app.post("/knowledge_base/delete_doc",
app.post("/knowledge_base/delete_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_doc)
)(delete_docs)
app.post("/knowledge_base/update_doc",
app.post("/knowledge_base/update_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="更新现有文件到知识库"
)(update_doc)
)(update_docs)
app.get("/knowledge_base/download_doc",
tags=["Knowledge Base Management"],

View File

@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL
from configs.model_config import EMBEDDING_MODEL, logger
from fastapi import Body
@@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
msg = f"创建知识库出错: {e}"
logger.error(msg)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
@@ -55,7 +56,8 @@ async def delete_kb(
if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
msg = f"删除知识库时出现意外: {e}"
logger.error(msg)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@@ -1,12 +1,17 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
logger,)
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from typing import List, Dict
from langchain.docstore.document import Document
@@ -44,11 +49,83 @@ async def list_files(
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
def _save_files_in_thread(files: List[UploadFile],
knowledge_base_name: str,
override: bool):
'''
Save uploaded files into the corresponding knowledge base directory using a thread pool.
A generator yielding save results: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
'''
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
'''
Save a single file.
'''
try:
filename = file.filename
file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
file_content = file.file.read() # read the uploaded file's content
if (os.path.isfile(file_path)
and not override
and os.path.getsize(file_path) == len(file_content)
):
# TODO: handle files whose size differs from the existing one
file_status = f"文件 {filename} 已存在。"
logger.warn(file_status)
return dict(code=404, msg=file_status, data=data)
with open(file_path, "wb") as f:
f.write(file_content)
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
except Exception as e:
msg = f"{filename} 文件上传失败,报错信息为: {e}"
logger.error(msg)
return dict(code=500, msg=msg, data=data)
params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
for result in run_in_thread_pool(save_file, params=params):
yield result
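_save_files_in_thread delegates the fan-out to run_in_thread_pool, which lives in server/utils and is not part of this diff. A minimal sketch of the behavior the callers rely on (an assumption, not the project's actual implementation):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, Generator, List

def run_in_thread_pool(func: Callable,
                       params: List[Dict] = [],
                       pool: ThreadPoolExecutor = None) -> Generator:
    # Submit one task per kwargs dict and yield each result as it completes.
    pool = pool or ThreadPoolExecutor()
    tasks = [pool.submit(func, **kwargs) for kwargs in params]
    for task in as_completed(tasks):
        yield task.result()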
# A separate file-upload-only API endpoint does not seem necessary
# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件")):
# '''
# API endpoint: upload files. Streams back save results: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
# '''
# def generate(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)
# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
# TODO: enable this once langchain.document_loaders supports in-memory files
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件"),
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
# def save_files(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)
# def files_to_docs(files):
# for result in files2docs_in_thread(files):
# yield json.dumps(result, ensure_ascii=False)
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
API endpoint: upload files and/or vectorize them
'''
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
file_content = await file.read() # read the uploaded file's content
failed_files = {}
file_names = list(docs.keys())
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
# First save the uploaded files to disk
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
filename = result["data"]["file_name"]
if result["code"] != 200:
failed_files[filename] = result["msg"]
if filename not in file_names:
file_names.append(filename)
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: handle files whose size differs from the existing one
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
# Vectorize the saved files
if to_vector_store:
result = await update_docs(
knowledge_base_name=knowledge_base_name,
file_names=file_names,
override_custom_docs=True,
docs=docs,
not_refresh_vs_cache=True,
)
failed_files.update(result.data["failed_files"])
if not not_refresh_vs_cache:
kb.save_vector_store()
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
try:
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
@@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
failed_files = {}
for file_name in file_names:
if not kb.exist_doc(file_name):
failed_files[file_name] = f"未找到文件 {file_name}"
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
except Exception as e:
msg = f"{file_name} 文件删除失败,错误信息:{e}"
logger.error(msg)
failed_files[file_name] = msg
if not not_refresh_vs_cache:
kb.save_vector_store()
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
async def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
@@ -127,22 +211,57 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
failed_files = {}
kb_files = []
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
# Build the list of files whose docs need to be loaded
for file_name in file_names:
file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
# If the file previously used custom docs, skip or override it according to the parameter
if file_detail.get("custom_docs") and not override_custom_docs:
continue
if file_name not in docs:
try:
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
except Exception as e:
msg = f"加载文档 {file_name} 时出错:{e}"
logger.error(msg)
failed_files[file_name] = msg
# Generate docs from the files and vectorize them.
# This uses KnowledgeFile's caching: Documents are loaded in worker threads, then handed back to a KnowledgeFile.
for status, result in files2docs_in_thread(kb_files):
if status:
kb_name, file_name, new_docs = result
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
kb_file.splited_docs = new_docs
kb.update_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
failed_files[file_name] = error
# Vectorize the custom docs
for file_name, v in docs.items():
try:
v = [x if isinstance(x, Document) else Document(**x) for x in v]
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
except Exception as e:
msg = f"{file_name} 添加自定义docs时出错{e}"
logger.error(msg)
failed_files[file_name] = msg
if not not_refresh_vs_cache:
kb.save_vector_store()
return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
async def download_doc(
knowledge_base_name: str = Query(..., examples=["samples"]),
file_name: str = Query(..., examples=["test.txt"]),
knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
'''
Download a knowledge base document
@@ -154,6 +273,11 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
if preview:
content_disposition_type = "inline"
else:
content_disposition_type = None
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
@@ -162,10 +286,13 @@ async def download_doc(
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
media_type="multipart/form-data",
content_disposition_type=content_disposition_type,
)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
logger.error(msg)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
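With `preview=true` the FileResponse is sent with Content-Disposition: inline, so the browser renders the file in place instead of saving it. A quick check, assuming a local server and the samples knowledge base:

import requests

params = {"knowledge_base_name": "samples", "file_name": "test.txt", "preview": True}
r = requests.get("http://127.0.0.1:7861/knowledge_base/download_doc", params=params)
print(r.headers.get("content-disposition"))  # should begin with "inline" when preview is on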
@@ -190,27 +317,30 @@ async def recreate_vector_store(
else:
kb.create_kb()
kb.clear_vs()
docs = list_files_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0
for status, result in files2docs_in_thread(kb_files):
if status:
kb_name, file_name, docs = result
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
kb_file.splited_docs = docs
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i,
"doc": doc,
"doc": file_name,
}, ensure_ascii=False)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
kb.add_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
"msg": msg,
})
i += 1
return StreamingResponse(output(), media_type="text/event-stream")

View File

@@ -49,6 +49,13 @@ class KBService(ABC):
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def save_vector_store(self, vector_store=None):
'''
Persist the vector store. Only FAISS is supported; for other vector stores this method does nothing.
This spares callers type checks when doing FAISS-specific operations.
'''
pass
def create_kb(self):
"""
Create the knowledge base
@@ -82,6 +89,8 @@
"""
if docs:
custom_docs = True
for doc in docs:
doc.metadata.setdefault("source", kb_file.filepath)
else:
docs = kb_file.file2text()
custom_docs = False
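Only the FAISS service gives save_vector_store a real body; everything else inherits the no-op. A hypothetical sketch of such an override (the actual FaissKBService code is outside this hunk), using langchain's FAISS.save_local and assuming the service keeps its index path in self.vs_path:

# Inside FaissKBService (hypothetical sketch, not this commit's code):
def save_vector_store(self, vector_store=None):
    # Persist the (possibly cached) FAISS index to the knowledge base's vector_store directory.
    vector_store = vector_store or self.load_vector_store()
    vector_store.save_local(self.vs_path)
    return vector_store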

View File

@@ -5,7 +5,8 @@ from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
SCORE_THRESHOLD
SCORE_THRESHOLD,
logger,
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from functools import lru_cache
@@ -28,7 +29,7 @@ def load_faiss_vector_store(
embeddings: Embeddings = None,
tick: int = 0, # tick is bumped by upload_doc etc. to force a cache refresh.
) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.")
logger.info(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
@@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
force the vector store cache to refresh on next load
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
class FaissKBService(KBService):
@@ -128,7 +129,7 @@ class FaissKBService(KBService):
**kwargs):
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
if len(ids) == 0:
return None

View File

@@ -7,7 +7,8 @@ from configs.model_config import (
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE
ZH_TITLE_ENHANCE,
logger,
)
from functools import lru_cache
import importlib
@@ -19,6 +20,7 @@ from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -175,12 +177,74 @@ def get_LoaderClass(file_extension):
return LoaderClass
# Shared vectorization logic pulled out of KnowledgeFile; once langchain supports in-memory files, non-disk content can be vectorized too
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
'''
Return a document loader according to loader_name and the file path or content.
'''
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
elif loader_name == "CSVLoader":
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
elif loader_name == "CustomJSONLoader":
loader = DocumentLoader(file_path_or_content, text_content=False)
elif loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
else:
loader = DocumentLoader(file_path_or_content)
return loader
def make_text_splitter(
splitter_name: str = "SpacyTextSplitter",
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
):
'''
Return the text splitter specified by the parameters.
'''
splitter_name = splitter_name or "SpacyTextSplitter"
text_splitter_module = importlib.import_module('langchain.text_splitter')
try:
TextSplitter = getattr(text_splitter_module, splitter_name)
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
except Exception as e:
logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
return text_splitter
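With the loader and splitter helpers decoupled from KnowledgeFile, they can be driven directly, which is what the planned in-memory path will need. A minimal sketch, assuming a markdown file at docs/FAQ.MD and an installed zh_core_web_sm spacy pipeline:

# Load and split a file without going through KnowledgeFile.
loader = get_loader("UnstructuredMarkdownLoader", "docs/FAQ.MD")
docs = loader.load()
text_splitter = make_text_splitter("SpacyTextSplitter")
chunks = text_splitter.split_documents(docs)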
class KnowledgeFile:
def __init__(
self,
filename: str,
knowledge_base_name: str
):
'''
A file in the knowledge base directory; it must exist on disk before vectorization and related operations.
'''
self.kb_name = knowledge_base_name
self.filename = filename
self.ext = os.path.splitext(filename)[-1].lower()
@@ -196,65 +260,11 @@ class KnowledgeFile:
def file2docs(self, refresh: bool=False):
if self.docs is None or refresh:
print(f"{self.document_loader_name} used for {self.filepath}")
try:
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
except Exception as e:
print(e)
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
elif self.document_loader_name == "CSVLoader":
loader = DocumentLoader(self.filepath, encoding="utf-8")
elif self.document_loader_name == "JSONLoader":
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
elif self.document_loader_name == "CustomJSONLoader":
loader = DocumentLoader(self.filepath, text_content=False)
elif self.document_loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(self.filepath, mode="elements")
elif self.document_loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(self.filepath, mode="elements")
else:
loader = DocumentLoader(self.filepath)
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
self.docs = loader.load()
return self.docs
def make_text_splitter(
self,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
):
try:
if self.text_splitter_name is None:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
self.text_splitter_name = "SpacyTextSplitter"
else:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
except Exception as e:
print(e)
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
return text_splitter
def docs2texts(
self,
docs: List[Document] = None,
@@ -265,10 +275,11 @@ class KnowledgeFile:
text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if not docs:
return []
if self.ext not in [".csv"]:
if text_splitter is None:
text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
docs = text_splitter.split_documents(docs)
print(f"文档切分示例:{docs[0]}")
@@ -286,13 +297,18 @@ class KnowledgeFile:
text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
docs = self.file2docs()
self.splited_docs = self.docs2texts(docs=docs,
using_zh_title_enhance=using_zh_title_enhance,
refresh=refresh,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
text_splitter=text_splitter)
return self.splited_docs
def file_exist(self):
return os.path.isfile(self.filepath)
def get_mtime(self):
return os.path.getmtime(self.filepath)
@@ -301,18 +317,21 @@
def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # tuples take the form (filename, kb_name)
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Convert files on disk into langchain Documents in batch using a thread pool.
The generator yields {(kb_name, file_name): docs}
If an item is passed as a tuple, it takes the form (filename, kb_name).
The generator yields status, (kb_name, file_name, docs | error)
'''
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
return False, e
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
logger.error(msg)
return False, (file.kb_name, file.filename, msg)
kwargs_list = []
for i, file in enumerate(files):

View File

@@ -0,0 +1,431 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from configs.model_config import (
embedding_model_dict,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
logger,
)
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
import document_loaders
import unstructured.partition
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool
import io
import builtins
from datetime import datetime
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
# Patch langchain.document_loaders and the project's own document_loaders, replacing their open function
# so that str, bytes, io.StringIO and io.BytesIO inputs can be vectorized
def _new_open(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], *args, **kw):
if isinstance(content, (io.StringIO, io.BytesIO)):
return content
if isinstance(content, str):
if os.path.isfile(content):
return builtins.open(content, *args, **kw)
else:
return io.StringIO(content)
if isinstance(content, bytes):
return io.BytesIO(content)
if isinstance(content, Path):
return content.open(*args, **kw)
return open(content, *args, **kw)
for module in [langchain.document_loaders, document_loaders]:
for k, v in module.__dict__.items():
if type(v) == type(langchain.document_loaders):
v.open = _new_open
# Patch unstructured so it does not fail on non-disk files
def _new_get_last_modified_date(filename: str) -> Union[str, None]:
try:
modify_date = datetime.fromtimestamp(os.path.getmtime(filename))
return modify_date.strftime("%Y-%m-%dT%H:%M:%S%z")
except:
return None
for k, v in unstructured.partition.__dict__.items():
if type(v) == type(unstructured.partition):
v.open = _new_open
v.get_last_modified_date = _new_get_last_modified_date
def validate_kb_name(knowledge_base_id: str) -> bool:
# Check for unexpected characters or path-traversal patterns
if "../" in knowledge_base_id:
return False
return True
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
def list_kbs_from_folder():
return [f for f in os.listdir(KB_ROOT_PATH)
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
@lru_cache(1)
def load_embeddings(model: str, device: str):
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt',
'.ppt', '.pptx', '.tsv'], # '.xlsx'
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
langchain's JSONLoader depends on jq, which is awkward to use on Windows; this class replaces it
'''
def __init__(
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
):
"""Initialize the JSONLoader.
Args:
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
content_key (str): The key to use to extract the content from the JSON if
the result is a list of objects (dicts).
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
object extracted by the jq_schema and the default metadata and returns
a dict of the updated metadata.
text_content (bool): Boolean flag to indicate whether the content is in
string format, default to True.
json_lines (bool): Boolean flag to indicate whether the input is in
JSON Lines format.
"""
self.file_path = Path(file_path).resolve()
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
self._json_lines = json_lines
# TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk encoding error on Windows.
# This is a workaround for langchain==0.0.266. I have opened a PR (#9785) against langchain; it should be deleted once langchain is upgraded.
def load(self) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
if self._json_lines:
with self.file_path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
self._parse(line, docs)
else:
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
return docs
def _parse(self, content: str, docs: List[Document]) -> None:
"""Convert given content to documents."""
data = json.loads(content)
# Perform some validation
# This is not a perfect validation, but it should catch most cases
# and prevent the user from getting a cryptic error later on.
if self._content_key is not None:
self._validate_content_key(data)
for i, sample in enumerate(data, len(docs) + 1):
metadata = dict(
source=str(self.file_path),
seq_num=i,
)
text = self._get_text(sample=sample, metadata=metadata)
docs.append(Document(page_content=text, metadata=metadata))
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
def get_LoaderClass(file_extension):
for LoaderClass, extensions in LOADER_DICT.items():
if file_extension in extensions:
return LoaderClass
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
'''
Return a document loader according to loader_name and the file path or content.
'''
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
logger.error(e)
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
elif loader_name == "CSVLoader":
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
elif loader_name == "CustomJSONLoader":
loader = DocumentLoader(file_path_or_content, text_content=False)
elif loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
else:
loader = DocumentLoader(file_path_or_content)
return loader
def make_text_splitter(
splitter_name: str = "SpacyTextSplitter",
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
):
'''
Return the text splitter specified by the parameters.
'''
splitter_name = splitter_name or "SpacyTextSplitter"
text_splitter_module = importlib.import_module('langchain.text_splitter')
try:
TextSplitter = getattr(text_splitter_module, splitter_name)
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
except Exception as e:
logger.error(e)
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
return text_splitter
def content_to_docs(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], ext: str = ".md") -> List[Document]:
'''
Convert disk files, text, bytes, in-memory files, etc. into Documents
'''
if not ext.startswith("."):
ext = "." + ext
ext = ext.lower()
if ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {ext}")
loader_name = get_LoaderClass(ext)
loader = get_loader(loader_name=loader_name, file_path_or_content=content)
return loader.load()
def split_docs(
docs: List[Document],
splitter_name: str = "SpacyTextSplitter",
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Document]:
text_splitter = make_text_splitter(splitter_name=splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
return text_splitter.split_documents(docs)
class KnowledgeFile:
def __init__(
self,
filename: str,
knowledge_base_name: str
):
'''
A file in the knowledge base directory; it must exist on disk before vectorization and related operations.
'''
self.kb_name = knowledge_base_name
self.filename = filename
self.ext = os.path.splitext(filename)[-1].lower()
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
# TODO: choose a text_splitter according to the file format
self.text_splitter_name = None
def file2docs(self, refresh: bool=False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
self.docs = loader.load()
return self.docs
def docs2texts(
self,
docs: List[Document] = None,
using_zh_title_enhance=ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if self.ext not in [".csv"]:
if text_splitter is None:
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
docs = text_splitter.split_documents(docs)
print(f"文档切分示例:{docs[0]}")
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
self.splited_docs = docs
return self.splited_docs
def file2text(
self,
using_zh_title_enhance=ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
refresh=refresh,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
text_splitter=text_splitter)
return self.splited_docs
def file_exist(self):
return os.path.isfile(self.filepath)
def get_mtime(self):
return os.path.getmtime(self.filepath)
def get_size(self):
return os.path.getsize(self.filepath)
def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # tuples take the form (filename, kb_name)
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Convert files on disk into langchain Documents in batch using a thread pool.
The generator yields status, (kb_name, file_name, docs | error)
'''
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
logger.error(msg)
return False, (file.kb_name, file.filename, msg)
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
if isinstance(file, tuple) and len(file) >= 2:
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
kwargs = file
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs["file"] = file
kwargs_list.append(kwargs)
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
yield result
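Consumers of files2docs_in_thread unpack the status flag first and then the payload tuple, as the API handlers earlier in this commit do. A minimal usage sketch, assuming a "samples" knowledge base with README.MD in its content directory:

for ok, payload in files2docs_in_thread([("README.MD", "samples")]):
    if ok:
        kb_name, file_name, docs = payload
        print(f"{file_name}: {len(docs)} documents")
    else:
        print("failed:", payload)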
if __name__ == "__main__":
from pprint import pprint
# kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
# # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
# docs = kb_file.file2docs()
# pprint(docs[-1])
# docs = kb_file.file2text()
# pprint(docs[-1])
docs = content_to_docs("""
## this is a title
## another title
how are you
this a wonderful day.
""", "txt")
pprint(docs)
pprint(split_docs(docs, chunk_size=10, chunk_overlap=0))

View File

@@ -7,17 +7,21 @@ root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
api_base_url = api_address()
api = ApiRequest(api_base_url)
kb = "kb_for_api_test"
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
"test.txt": get_file_path("samples", "test.txt"),
}
@@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
assert kb in data["data"]
def test_upload_doc(api="/knowledge_base/upload_doc"):
def test_upload_docs(api="/knowledge_base/upload_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n上传知识文件: {name}")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"文件 {name} 已存在。"
print(f"\n上传知识文件")
data = {"knowledge_base_name": kb, "override": True}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 覆盖")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
print(f"\n尝试重新上传知识文件, 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == len(test_files)
print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_list_files(api="/knowledge_base/list_files"):
@@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_doc(api="/knowledge_base/update_doc"):
def test_update_docs(api="/knowledge_base/update_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n更新知识文件 {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功更新文件 {name}"
print(f"\n更新知识文件")
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_delete_doc(api="/knowledge_base/delete_doc"):
def test_delete_docs(api="/knowledge_base/delete_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n删除知识文件 {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"{name} 文件删除成功"
print(f"\n删除知识文件")
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
url = api_base_url + "/knowledge_base/search_docs"
query = "介绍一下langchain-chatchat项目"

View File

@@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
import contextlib
import json
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async, set_httpx_timeout
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
from configs.model_config import NLTK_DATA_PATH
import nltk
@@ -43,7 +41,7 @@ class ApiRequest:
'''
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
@@ -78,7 +76,7 @@ class ApiRequest:
else:
return httpx.get(url, params=params, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when get {url}: {e}")
retry -= 1
async def aget(
@@ -99,7 +97,7 @@ class ApiRequest:
else:
return await client.get(url, params=params, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when aget {url}: {e}")
retry -= 1
def post(
@@ -121,7 +119,7 @@ class ApiRequest:
else:
return httpx.post(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when post {url}: {e}")
retry -= 1
async def apost(
@@ -143,7 +141,7 @@ class ApiRequest:
else:
return await client.post(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when apost {url}: {e}")
retry -= 1
def delete(
@@ -164,7 +162,7 @@ class ApiRequest:
else:
return httpx.delete(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when delete {url}: {e}")
retry -= 1
async def adelete(
@@ -186,7 +184,7 @@ class ApiRequest:
else:
return await client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
logger.error(f"error when adelete {url}: {e}")
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
@@ -205,7 +203,7 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except Exception as e:
logger.error(e)
logger.error(f"error when run fastapi router: {e}")
def _httpx_stream2generator(
self,
@@ -231,18 +229,18 @@ class ApiRequest:
print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg)
logger.error(msg)
logger.error(e)
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'。({e}"
logger.error(msg)
logger.error(e)
yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "msg": str(e)}
msg = f"API通信遇到错误{e}"
logger.error(msg)
yield {"code": 500, "msg": msg}
# Chat-related operations
@@ -413,8 +411,9 @@ class ApiRequest:
try:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "msg": errorMsg or str(e)}
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
logger.error(msg)
return {"code": 500, "msg": msg}
def list_knowledge_bases(
self,
@@ -510,12 +509,13 @@ class ApiRequest:
data = self._check_httpx_json_response(response)
return data.get("data", [])
def upload_kb_doc(
def upload_kb_docs(
self,
file: Union[str, Path, bytes],
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
filename: str = None,
override: bool = False,
to_vector_store: bool = True,
docs: List[Dict] = [],
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):