mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
新功能:
- 知识库管理中的add_docs/delete_docs/update_docs均支持批量操作,并利用多线程提高效率 - API的重建知识库接口支持多线程 - add_docs可提供参数控制上传文件后是否继续进行向量化 - add_docs/update_docs支持传入自定义docs(以json形式)。后续考虑区分完整或补充式自定义docs - download_doc接口添加`preview`参数,支持下载或预览 - kb_service增加`save_vector_store`方法,便于保存向量库(仅FAISS,其它无操作) - 将document_loader & text_splitter逻辑从KnowledgeFile中抽离出来,为后续对内存文件进行向量化做准备 - KowledgeFile支持docs & splitted_docs的缓存,方便在中间过程做一些自定义 其它: - 将部分错误输出由print改为logger.error
This commit is contained in:
parent
93b133f9ac
commit
661a0e9d72
@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
|
||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
search_engine_chat)
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
|
||||
update_doc, download_doc, recreate_vector_store,
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||
update_docs, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
||||
import httpx
|
||||
@ -98,23 +98,23 @@ def create_app():
|
||||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
|
||||
app.post("/knowledge_base/upload_doc",
|
||||
app.post("/knowledge_base/upload_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="上传文件到知识库"
|
||||
)(upload_doc)
|
||||
summary="上传文件到知识库,并/或进行向量化"
|
||||
)(upload_docs)
|
||||
|
||||
app.post("/knowledge_base/delete_doc",
|
||||
app.post("/knowledge_base/delete_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_doc)
|
||||
)(delete_docs)
|
||||
|
||||
app.post("/knowledge_base/update_doc",
|
||||
app.post("/knowledge_base/update_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="更新现有文件到知识库"
|
||||
)(update_doc)
|
||||
)(update_docs)
|
||||
|
||||
app.get("/knowledge_base/download_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
|
||||
@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
||||
from configs.model_config import EMBEDDING_MODEL
|
||||
from configs.model_config import EMBEDDING_MODEL, logger
|
||||
from fastapi import Body
|
||||
|
||||
|
||||
@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
|
||||
msg = f"创建知识库出错: {e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
@ -55,7 +56,8 @@ async def delete_kb(
|
||||
if status:
|
||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
|
||||
msg = f"删除知识库时出现意外: {e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
import os
|
||||
import urllib
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
logger,)
|
||||
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
||||
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
|
||||
files2docs_in_thread, KnowledgeFile)
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from pydantic import Json
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from typing import List, Dict
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
@ -44,11 +49,83 @@ async def list_files(
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
||||
def _save_files_in_thread(files: List[UploadFile],
|
||||
knowledge_base_name: str,
|
||||
override: bool):
|
||||
'''
|
||||
通过多线程将上传的文件保存到对应知识库目录内。
|
||||
生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||
'''
|
||||
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
|
||||
'''
|
||||
保存单个文件。
|
||||
'''
|
||||
try:
|
||||
filename = file.filename
|
||||
file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
|
||||
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
|
||||
|
||||
file_content = file.file.read() # 读取上传文件的内容
|
||||
if (os.path.isfile(file_path)
|
||||
and not override
|
||||
and os.path.getsize(file_path) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {filename} 已存在。"
|
||||
logger.warn(file_status)
|
||||
return dict(code=404, msg=file_status, data=data)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
|
||||
except Exception as e:
|
||||
msg = f"{filename} 文件上传失败,报错信息为: {e}"
|
||||
logger.error(msg)
|
||||
return dict(code=500, msg=msg, data=data)
|
||||
|
||||
params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
|
||||
for result in run_in_thread_pool(save_file, params=params):
|
||||
yield result
|
||||
|
||||
|
||||
# 似乎没有单独增加一个文件上传API接口的必要
|
||||
# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
# override: bool = Form(False, description="覆盖已有文件")):
|
||||
# '''
|
||||
# API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||
# '''
|
||||
# def generate(files, knowledge_base_name, override):
|
||||
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
|
||||
|
||||
|
||||
# TODO: 等langchain.document_loaders支持内存文件的时候再开通
|
||||
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
# override: bool = Form(False, description="覆盖已有文件"),
|
||||
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
|
||||
# def save_files(files, knowledge_base_name, override):
|
||||
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# def files_to_docs(files):
|
||||
# for result in files2docs_in_thread(files):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
override: bool = Form(False, description="覆盖已有文件"),
|
||||
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
|
||||
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
API接口:上传文件,并/或向量化
|
||||
'''
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
file_content = await file.read() # 读取上传文件的内容
|
||||
failed_files = {}
|
||||
file_names = list(docs.keys())
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file.filename,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
# 先将上传的文件保存到磁盘
|
||||
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
filename = result["data"]["file_name"]
|
||||
if result["code"] != 200:
|
||||
failed_files[filename] = result["msg"]
|
||||
|
||||
if filename not in file_names:
|
||||
file_names.append(filename)
|
||||
|
||||
if (os.path.exists(kb_file.filepath)
|
||||
and not override
|
||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
||||
):
|
||||
# TODO: filesize 不同后的处理
|
||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
||||
return BaseResponse(code=404, msg=file_status)
|
||||
# 对保存的文件进行向量化
|
||||
if to_vector_store:
|
||||
result = await update_docs(
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
file_names=file_names,
|
||||
override_custom_docs=True,
|
||||
docs=docs,
|
||||
not_refresh_vs_cache=True,
|
||||
)
|
||||
failed_files.update(result.data["failed_files"])
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
with open(kb_file.filepath, "wb") as f:
|
||||
f.write(file_content)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
||||
|
||||
try:
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
||||
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
|
||||
delete_content: bool = Body(False),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
if not kb.exist_doc(doc_name):
|
||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
||||
failed_files = {}
|
||||
for file_name in file_names:
|
||||
if not kb.exist_doc(file_name):
|
||||
failed_files[file_name] = f"未找到文件 {file_name}"
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=doc_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
|
||||
except Exception as e:
|
||||
msg = f"{file_name} 文件删除失败,错误信息:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
||||
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
async def update_docs(
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
|
||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
@ -127,22 +211,57 @@ async def update_doc(
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
if os.path.exists(kb_file.filepath):
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
|
||||
failed_files = {}
|
||||
kb_files = []
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
||||
# 生成需要加载docs的文件列表
|
||||
for file_name in file_names:
|
||||
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
|
||||
# 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
|
||||
if file_detail.get("custom_docs") and not override_custom_docs:
|
||||
continue
|
||||
if file_name not in docs:
|
||||
try:
|
||||
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
|
||||
except Exception as e:
|
||||
msg = f"加载文档 {file_name} 时出错:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
# 从文件生成docs,并进行向量化。
|
||||
# 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile
|
||||
for status, result in files2docs_in_thread(kb_files):
|
||||
if status:
|
||||
kb_name, file_name, new_docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
kb_file.splited_docs = new_docs
|
||||
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
failed_files[file_name] = error
|
||||
|
||||
# 将自定义的docs进行向量化
|
||||
for file_name, v in docs.items():
|
||||
try:
|
||||
v = [x if isinstance(x, Document) else Document(**x) for x in v]
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
|
||||
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
|
||||
except Exception as e:
|
||||
msg = f"为 {file_name} 添加自定义docs时出错:{e}"
|
||||
logger.error(msg)
|
||||
failed_files[file_name] = msg
|
||||
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
async def download_doc(
|
||||
knowledge_base_name: str = Query(..., examples=["samples"]),
|
||||
file_name: str = Query(..., examples=["test.txt"]),
|
||||
knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
|
||||
file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
|
||||
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
|
||||
):
|
||||
'''
|
||||
下载知识库文档
|
||||
@ -154,6 +273,11 @@ async def download_doc(
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
if preview:
|
||||
content_disposition_type = "inline"
|
||||
else:
|
||||
content_disposition_type = None
|
||||
|
||||
try:
|
||||
kb_file = KnowledgeFile(filename=file_name,
|
||||
knowledge_base_name=knowledge_base_name)
|
||||
@ -162,10 +286,13 @@ async def download_doc(
|
||||
return FileResponse(
|
||||
path=kb_file.filepath,
|
||||
filename=kb_file.filename,
|
||||
media_type="multipart/form-data")
|
||||
media_type="multipart/form-data",
|
||||
content_disposition_type=content_disposition_type,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
|
||||
msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
|
||||
logger.error(msg)
|
||||
return BaseResponse(code=500, msg=msg)
|
||||
|
||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||
|
||||
@ -190,27 +317,30 @@ async def recreate_vector_store(
|
||||
else:
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_files_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(docs)}): {doc}",
|
||||
"total": len(docs),
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i,
|
||||
"doc": doc,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
if i == len(docs) - 1:
|
||||
not_refresh_vs_cache = False
|
||||
else:
|
||||
not_refresh_vs_cache = True
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
|
||||
return StreamingResponse(output(), media_type="text/event-stream")
|
||||
|
||||
@ -49,6 +49,13 @@ class KBService(ABC):
|
||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||
return load_embeddings(self.embed_model, embed_device)
|
||||
|
||||
def save_vector_store(self, vector_store=None):
|
||||
'''
|
||||
保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。
|
||||
减少FAISS向量库操作时的类型判断。
|
||||
'''
|
||||
pass
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
@ -82,6 +89,8 @@ class KBService(ABC):
|
||||
"""
|
||||
if docs:
|
||||
custom_docs = True
|
||||
for doc in docs:
|
||||
doc.metadata.setdefault("source", kb_file.filepath)
|
||||
else:
|
||||
docs = kb_file.file2text()
|
||||
custom_docs = False
|
||||
|
||||
@ -5,7 +5,8 @@ from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
CACHED_VS_NUM,
|
||||
EMBEDDING_MODEL,
|
||||
SCORE_THRESHOLD
|
||||
SCORE_THRESHOLD,
|
||||
logger,
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from functools import lru_cache
|
||||
@ -28,7 +29,7 @@ def load_faiss_vector_store(
|
||||
embeddings: Embeddings = None,
|
||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
) -> FAISS:
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
logger.info(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
|
||||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
@ -128,7 +129,7 @@ class FaissKBService(KBService):
|
||||
**kwargs):
|
||||
vector_store = self.load_vector_store()
|
||||
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
)
|
||||
from functools import lru_cache
|
||||
import importlib
|
||||
@ -19,6 +20,7 @@ from pathlib import Path
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from server.utils import run_in_thread_pool
|
||||
import io
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
|
||||
|
||||
@ -175,12 +177,74 @@ def get_LoaderClass(file_extension):
|
||||
return LoaderClass
|
||||
|
||||
|
||||
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
||||
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||
'''
|
||||
根据loader_name和文件路径或内容返回文档加载器。
|
||||
'''
|
||||
try:
|
||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||
except Exception as e:
|
||||
logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
|
||||
if loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||
elif loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
|
||||
elif loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||
elif loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, text_content=False)
|
||||
elif loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(file_path_or_content)
|
||||
return loader
|
||||
|
||||
|
||||
def make_text_splitter(
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
):
|
||||
'''
|
||||
根据参数获取特定的分词器
|
||||
'''
|
||||
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
try:
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
return text_splitter
|
||||
|
||||
class KnowledgeFile:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
knowledge_base_name: str
|
||||
):
|
||||
'''
|
||||
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||
'''
|
||||
self.kb_name = knowledge_base_name
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1].lower()
|
||||
@ -196,65 +260,11 @@ class KnowledgeFile:
|
||||
|
||||
def file2docs(self, refresh: bool=False):
|
||||
if self.docs is None or refresh:
|
||||
print(f"{self.document_loader_name} used for {self.filepath}")
|
||||
try:
|
||||
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
elif self.document_loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
||||
elif self.document_loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
||||
elif self.document_loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(self.filepath, text_content=False)
|
||||
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||
loader = get_loader(self.document_loader_name, self.filepath)
|
||||
self.docs = loader.load()
|
||||
return self.docs
|
||||
|
||||
def make_text_splitter(
|
||||
self,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
):
|
||||
try:
|
||||
if self.text_splitter_name is None:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
self.text_splitter_name = "SpacyTextSplitter"
|
||||
else:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
return text_splitter
|
||||
|
||||
def docs2texts(
|
||||
self,
|
||||
docs: List[Document] = None,
|
||||
@ -265,10 +275,11 @@ class KnowledgeFile:
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
docs = docs or self.file2docs(refresh=refresh)
|
||||
|
||||
if not docs:
|
||||
return []
|
||||
if self.ext not in [".csv"]:
|
||||
if text_splitter is None:
|
||||
text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
print(f"文档切分示例:{docs[0]}")
|
||||
@ -286,13 +297,18 @@ class KnowledgeFile:
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
if self.splited_docs is None or refresh:
|
||||
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
|
||||
docs = self.file2docs()
|
||||
self.splited_docs = self.docs2texts(docs=docs,
|
||||
using_zh_title_enhance=using_zh_title_enhance,
|
||||
refresh=refresh,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
text_splitter=text_splitter)
|
||||
return self.splited_docs
|
||||
|
||||
def file_exist(self):
|
||||
return os.path.isfile(self.filepath)
|
||||
|
||||
def get_mtime(self):
|
||||
return os.path.getmtime(self.filepath)
|
||||
|
||||
@ -301,18 +317,21 @@ class KnowledgeFile:
|
||||
|
||||
|
||||
def files2docs_in_thread(
|
||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name)
|
||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
利用多线程批量将磁盘文件转化成langchain Document.
|
||||
生成器返回值为{(kb_name, file_name): docs}
|
||||
如果传入参数是Tuple,形式为(filename, kb_name)
|
||||
生成器返回值为 status, (kb_name, file_name, docs | error)
|
||||
'''
|
||||
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
||||
try:
|
||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||
except Exception as e:
|
||||
return False, e
|
||||
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
|
||||
logger.error(msg)
|
||||
return False, (file.kb_name, file.filename, msg)
|
||||
|
||||
kwargs_list = []
|
||||
for i, file in enumerate(files):
|
||||
|
||||
431
server/knowledge_base/utils.py.bak
Normal file
431
server/knowledge_base/utils.py.bak
Normal file
@ -0,0 +1,431 @@
|
||||
import os
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from configs.model_config import (
|
||||
embedding_model_dict,
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
)
|
||||
from functools import lru_cache
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
import document_loaders
|
||||
import unstructured.partition
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from pathlib import Path
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from server.utils import run_in_thread_pool
|
||||
import io
|
||||
import builtins
|
||||
from datetime import datetime
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
if isinstance(self, HuggingFaceEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, OpenAIEmbeddings):
|
||||
return hash(self.model)
|
||||
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
|
||||
# patch langchain.document_loaders和项目自定义document_loaders,替换其中的open函数。
|
||||
# 使其支持对str,bytes,io.StringIO,io.BytesIO进行向量化
|
||||
def _new_open(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], *args, **kw):
|
||||
if isinstance(content, (io.StringIO, io.BytesIO)):
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
if os.path.isfile(content):
|
||||
return builtins.open(content, *args, **kw)
|
||||
else:
|
||||
return io.StringIO(content)
|
||||
if isinstance(content, bytes):
|
||||
return io.BytesIO(bytes)
|
||||
if isinstance(content, Path):
|
||||
return Path.open(*args, **kw)
|
||||
return open(content, *args, **kw)
|
||||
|
||||
for module in [langchain.document_loaders, document_loaders]:
|
||||
for k, v in module.__dict__.items():
|
||||
if type(v) == type(langchain.document_loaders):
|
||||
v.open = _new_open
|
||||
|
||||
# path unstructured 使其在处理非磁盘文件时不会出错
|
||||
def _new_get_last_modified_date(filename: str) -> Union[str, None]:
|
||||
try:
|
||||
modify_date = datetime.fromtimestamp(os.path.getmtime(filename))
|
||||
return modify_date.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
except:
|
||||
return None
|
||||
|
||||
for k, v in unstructured.partition.__dict__.items():
|
||||
if type(v) == type(unstructured.partition):
|
||||
v.open = _new_open
|
||||
v.get_last_modified_date = _new_get_last_modified_date
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
# 检查是否包含预期外的字符或路径攻击关键字
|
||||
if "../" in knowledge_base_id:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
|
||||
|
||||
def get_doc_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||
|
||||
|
||||
def list_kbs_from_folder():
|
||||
return [f for f in os.listdir(KB_ROOT_PATH)
|
||||
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
|
||||
|
||||
|
||||
def list_files_from_folder(kb_name: str):
|
||||
doc_path = get_doc_path(kb_name)
|
||||
return [file for file in os.listdir(doc_path)
|
||||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
|
||||
elif 'bge-' in model:
|
||||
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device},
|
||||
query_instruction="为这个句子生成表示以用于检索相关文章:")
|
||||
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
|
||||
embeddings.query_instruction = ""
|
||||
else:
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
|
||||
return embeddings
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||
"UnstructuredMarkdownLoader": ['.md'],
|
||||
"CustomJSONLoader": [".json"],
|
||||
"CSVLoader": [".csv"],
|
||||
"RapidOCRPDFLoader": [".pdf"],
|
||||
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
|
||||
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt',
|
||||
'.ppt', '.pptx', '.tsv'], # '.xlsx'
|
||||
}
|
||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||
|
||||
|
||||
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
|
||||
'''
|
||||
langchain的JSONLoader需要jq,在win上使用不便,进行替代。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
json_lines: bool = False,
|
||||
):
|
||||
"""Initialize the JSONLoader.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
|
||||
content_key (str): The key to use to extract the content from the JSON if
|
||||
results to a list of objects (dict).
|
||||
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
|
||||
object extracted by the jq_schema and the default metadata and returns
|
||||
a dict of the updated metadata.
|
||||
text_content (bool): Boolean flag to indicate whether the content is in
|
||||
string format, default to True.
|
||||
json_lines (bool): Boolean flag to indicate whether the input is in
|
||||
JSON Lines format.
|
||||
"""
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
self._json_lines = json_lines
|
||||
|
||||
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
|
||||
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
|
||||
def load(self) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
if self._json_lines:
|
||||
with self.file_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._parse(line, docs)
|
||||
else:
|
||||
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, content: str, docs: List[Document]) -> None:
|
||||
"""Convert given content to documents."""
|
||||
data = json.loads(content)
|
||||
|
||||
# Perform some validation
|
||||
# This is not a perfect validation, but it should catch most cases
|
||||
# and prevent the user from getting a cryptic error later on.
|
||||
if self._content_key is not None:
|
||||
self._validate_content_key(data)
|
||||
|
||||
for i, sample in enumerate(data, len(docs) + 1):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=i,
|
||||
)
|
||||
text = self._get_text(sample=sample, metadata=metadata)
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
|
||||
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
|
||||
|
||||
|
||||
def get_LoaderClass(file_extension):
|
||||
for LoaderClass, extensions in LOADER_DICT.items():
|
||||
if file_extension in extensions:
|
||||
return LoaderClass
|
||||
|
||||
|
||||
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||
'''
|
||||
根据loader_name和文件路径或内容返回文档加载器。
|
||||
'''
|
||||
try:
|
||||
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
|
||||
if loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||
elif loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
|
||||
elif loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||
elif loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(file_path_or_content, text_content=False)
|
||||
elif loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
elif loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(file_path_or_content)
|
||||
return loader
|
||||
|
||||
|
||||
def make_text_splitter(
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
):
|
||||
'''
|
||||
根据参数获取特定的分词器
|
||||
'''
|
||||
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
try:
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
return text_splitter
|
||||
|
||||
|
||||
def content_to_docs(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], ext: str = ".md") -> List[Document]:
|
||||
'''
|
||||
将磁盘文件、文本、字节、内存文件等转化成Document
|
||||
'''
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
ext = ext.lower()
|
||||
if ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {ext}")
|
||||
|
||||
loader_name = get_LoaderClass(ext)
|
||||
loader = get_loader(loader_name=loader_name, file_path_or_content=content)
|
||||
return loader.load()
|
||||
|
||||
|
||||
def split_docs(
|
||||
docs: List[Document],
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
) -> List[Document]:
|
||||
text_splitter = make_text_splitter(splitter_name=splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
return text_splitter.split_documents(docs)
|
||||
|
||||
|
||||
class KnowledgeFile:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
knowledge_base_name: str
|
||||
):
|
||||
'''
|
||||
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||
'''
|
||||
self.kb_name = knowledge_base_name
|
||||
self.filename = filename
|
||||
self.ext = os.path.splitext(filename)[-1].lower()
|
||||
if self.ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||
self.docs = None
|
||||
self.splited_docs = None
|
||||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = None
|
||||
|
||||
def file2docs(self, refresh: bool=False):
|
||||
if self.docs is None or refresh:
|
||||
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||
loader = get_loader(self.document_loader_name, self.filepath)
|
||||
self.docs = loader.load()
|
||||
return self.docs
|
||||
|
||||
def docs2texts(
|
||||
self,
|
||||
docs: List[Document] = None,
|
||||
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||
refresh: bool = False,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
docs = docs or self.file2docs(refresh=refresh)
|
||||
|
||||
if self.ext not in [".csv"]:
|
||||
if text_splitter is None:
|
||||
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
print(f"文档切分示例:{docs[0]}")
|
||||
if using_zh_title_enhance:
|
||||
docs = zh_title_enhance(docs)
|
||||
self.splited_docs = docs
|
||||
return self.splited_docs
|
||||
|
||||
def file2text(
|
||||
self,
|
||||
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||
refresh: bool = False,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
text_splitter: TextSplitter = None,
|
||||
):
|
||||
if self.splited_docs is None or refresh:
|
||||
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
|
||||
refresh=refresh,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
text_splitter=text_splitter)
|
||||
return self.splited_docs
|
||||
|
||||
def file_exist(self):
|
||||
return os.path.isfile(self.filepath)
|
||||
|
||||
def get_mtime(self):
|
||||
return os.path.getmtime(self.filepath)
|
||||
|
||||
def get_size(self):
|
||||
return os.path.getsize(self.filepath)
|
||||
|
||||
|
||||
def files2docs_in_thread(
|
||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name)
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
利用多线程批量将磁盘文件转化成langchain Document.
|
||||
生成器返回值为 status, (kb_name, file_name, docs)
|
||||
'''
|
||||
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
||||
try:
|
||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||
except Exception as e:
|
||||
return False, e
|
||||
|
||||
kwargs_list = []
|
||||
for i, file in enumerate(files):
|
||||
kwargs = {}
|
||||
if isinstance(file, tuple) and len(file) >= 2:
|
||||
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
|
||||
elif isinstance(file, dict):
|
||||
filename = file.pop("filename")
|
||||
kb_name = file.pop("kb_name")
|
||||
kwargs = file
|
||||
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
|
||||
kwargs["file"] = file
|
||||
kwargs_list.append(kwargs)
|
||||
|
||||
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
|
||||
yield result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pprint import pprint
|
||||
|
||||
# kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
|
||||
# # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
|
||||
# docs = kb_file.file2docs()
|
||||
# pprint(docs[-1])
|
||||
|
||||
# docs = kb_file.file2text()
|
||||
# pprint(docs[-1])
|
||||
|
||||
docs = content_to_docs("""
|
||||
## this is a title
|
||||
|
||||
## another title
|
||||
|
||||
how are you
|
||||
this a wonderful day.
|
||||
""", "txt")
|
||||
pprint(docs)
|
||||
pprint(split_docs(docs, chunk_size=10, chunk_overlap=0))
|
||||
@ -7,17 +7,21 @@ root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
|
||||
@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
||||
assert kb in data["data"]
|
||||
|
||||
|
||||
def test_upload_doc(api="/knowledge_base/upload_doc"):
|
||||
def test_upload_docs(api="/knowledge_base/upload_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n上传知识文件: {name}")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"文件 {name} 已存在。"
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
for name, path in test_files.items():
|
||||
print(f"\n尝试重新上传知识文件: {name}, 覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
files = {"file": (name, open(path, "rb"))}
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功上传文件 {name}"
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||
r = requests.post(url, data=data, files=files)
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files(api="/knowledge_base/list_files"):
|
||||
@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_doc(api="/knowledge_base/update_doc"):
|
||||
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n更新知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"成功更新文件 {name}"
|
||||
|
||||
print(f"\n更新知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_doc(api="/knowledge_base/delete_doc"):
|
||||
def test_delete_docs(api="/knowledge_base/delete_docs"):
|
||||
url = api_base_url + api
|
||||
for name, path in test_files.items():
|
||||
print(f"\n删除知识文件: {name}")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"{name} 文件删除成功"
|
||||
|
||||
print(f"\n删除知识文件")
|
||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
url = api_base_url + "/knowledge_base/search_docs"
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
|
||||
@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
|
||||
import contextlib
|
||||
import json
|
||||
from io import BytesIO
|
||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from server.utils import run_async, iter_over_async, set_httpx_timeout
|
||||
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
|
||||
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
import nltk
|
||||
@ -43,7 +41,7 @@ class ApiRequest:
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:7861",
|
||||
base_url: str = api_address(),
|
||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||
no_remote_api: bool = False, # call api view function directly
|
||||
):
|
||||
@ -78,7 +76,7 @@ class ApiRequest:
|
||||
else:
|
||||
return httpx.get(url, params=params, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when get {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def aget(
|
||||
@ -99,7 +97,7 @@ class ApiRequest:
|
||||
else:
|
||||
return await client.get(url, params=params, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when aget {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def post(
|
||||
@ -121,7 +119,7 @@ class ApiRequest:
|
||||
else:
|
||||
return httpx.post(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when post {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def apost(
|
||||
@ -143,7 +141,7 @@ class ApiRequest:
|
||||
else:
|
||||
return await client.post(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when apost {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def delete(
|
||||
@ -164,7 +162,7 @@ class ApiRequest:
|
||||
else:
|
||||
return httpx.delete(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when delete {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
async def adelete(
|
||||
@ -186,7 +184,7 @@ class ApiRequest:
|
||||
else:
|
||||
return await client.delete(url, data=data, json=json, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when adelete {url}: {e}")
|
||||
retry -= 1
|
||||
|
||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||
@ -205,7 +203,7 @@ class ApiRequest:
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(f"error when run fastapi router: {e}")
|
||||
|
||||
def _httpx_stream2generator(
|
||||
self,
|
||||
@ -231,18 +229,18 @@ class ApiRequest:
|
||||
print(chunk, end="", flush=True)
|
||||
yield chunk
|
||||
except httpx.ConnectError as e:
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
||||
logger.error(msg)
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": msg}
|
||||
except httpx.ReadTimeout as e:
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')。({e})"
|
||||
logger.error(msg)
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": msg}
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
yield {"code": 500, "msg": str(e)}
|
||||
msg = f"API通信遇到错误:{e}"
|
||||
logger.error(msg)
|
||||
yield {"code": 500, "msg": msg}
|
||||
|
||||
# 对话相关操作
|
||||
|
||||
@ -413,8 +411,9 @@ class ApiRequest:
|
||||
try:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"code": 500, "msg": errorMsg or str(e)}
|
||||
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
def list_knowledge_bases(
|
||||
self,
|
||||
@ -510,12 +509,13 @@ class ApiRequest:
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
def upload_kb_doc(
|
||||
def upload_kb_docs(
|
||||
self,
|
||||
file: Union[str, Path, bytes],
|
||||
files: List[Union[str, Path, bytes]],
|
||||
knowledge_base_name: str,
|
||||
filename: str = None,
|
||||
override: bool = False,
|
||||
to_vector_store: bool = True,
|
||||
docs: List[Dict] = [],
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user