mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 23:43:30 +08:00
新功能:
- 知识库管理中的add_docs/delete_docs/update_docs均支持批量操作,并利用多线程提高效率 - API的重建知识库接口支持多线程 - add_docs可提供参数控制上传文件后是否继续进行向量化 - add_docs/update_docs支持传入自定义docs(以json形式)。后续考虑区分完整或补充式自定义docs - download_doc接口添加`preview`参数,支持下载或预览 - kb_service增加`save_vector_store`方法,便于保存向量库(仅FAISS,其它无操作) - 将document_loader & text_splitter逻辑从KnowledgeFile中抽离出来,为后续对内存文件进行向量化做准备 - KowledgeFile支持docs & splitted_docs的缓存,方便在中间过程做一些自定义 其它: - 将部分错误输出由print改为logger.error
This commit is contained in:
parent
93b133f9ac
commit
661a0e9d72
@ -15,8 +15,8 @@ from starlette.responses import RedirectResponse
|
|||||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||||
search_engine_chat)
|
search_engine_chat)
|
||||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||||
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
|
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||||
update_doc, download_doc, recreate_vector_store,
|
update_docs, download_doc, recreate_vector_store,
|
||||||
search_docs, DocumentWithScore)
|
search_docs, DocumentWithScore)
|
||||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
||||||
import httpx
|
import httpx
|
||||||
@ -98,23 +98,23 @@ def create_app():
|
|||||||
summary="搜索知识库"
|
summary="搜索知识库"
|
||||||
)(search_docs)
|
)(search_docs)
|
||||||
|
|
||||||
app.post("/knowledge_base/upload_doc",
|
app.post("/knowledge_base/upload_docs",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
summary="上传文件到知识库"
|
summary="上传文件到知识库,并/或进行向量化"
|
||||||
)(upload_doc)
|
)(upload_docs)
|
||||||
|
|
||||||
app.post("/knowledge_base/delete_doc",
|
app.post("/knowledge_base/delete_docs",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
summary="删除知识库内指定文件"
|
summary="删除知识库内指定文件"
|
||||||
)(delete_doc)
|
)(delete_docs)
|
||||||
|
|
||||||
app.post("/knowledge_base/update_doc",
|
app.post("/knowledge_base/update_docs",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
response_model=BaseResponse,
|
response_model=BaseResponse,
|
||||||
summary="更新现有文件到知识库"
|
summary="更新现有文件到知识库"
|
||||||
)(update_doc)
|
)(update_docs)
|
||||||
|
|
||||||
app.get("/knowledge_base/download_doc",
|
app.get("/knowledge_base/download_doc",
|
||||||
tags=["Knowledge Base Management"],
|
tags=["Knowledge Base Management"],
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
|
|||||||
from server.knowledge_base.utils import validate_kb_name
|
from server.knowledge_base.utils import validate_kb_name
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
||||||
from configs.model_config import EMBEDDING_MODEL
|
from configs.model_config import EMBEDDING_MODEL, logger
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
|
|
||||||
|
|
||||||
@ -30,8 +30,9 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||||||
try:
|
try:
|
||||||
kb.create_kb()
|
kb.create_kb()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
msg = f"创建知识库出错: {e}"
|
||||||
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
|
logger.error(msg)
|
||||||
|
return BaseResponse(code=500, msg=msg)
|
||||||
|
|
||||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
@ -55,7 +56,8 @@ async def delete_kb(
|
|||||||
if status:
|
if status:
|
||||||
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
msg = f"删除知识库时出现意外: {e}"
|
||||||
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
|
logger.error(msg)
|
||||||
|
return BaseResponse(code=500, msg=msg)
|
||||||
|
|
||||||
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
from fastapi import File, Form, Body, Query, UploadFile
|
from fastapi import File, Form, Body, Query, UploadFile
|
||||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||||
from server.utils import BaseResponse, ListResponse
|
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
|
logger,)
|
||||||
|
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
||||||
|
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
|
||||||
|
files2docs_in_thread, KnowledgeFile)
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
|
from pydantic import Json
|
||||||
import json
|
import json
|
||||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||||
|
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
@ -44,11 +49,83 @@ async def list_files(
|
|||||||
return ListResponse(data=all_doc_names)
|
return ListResponse(data=all_doc_names)
|
||||||
|
|
||||||
|
|
||||||
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
def _save_files_in_thread(files: List[UploadFile],
|
||||||
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
|
knowledge_base_name: str,
|
||||||
|
override: bool):
|
||||||
|
'''
|
||||||
|
通过多线程将上传的文件保存到对应知识库目录内。
|
||||||
|
生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||||
|
'''
|
||||||
|
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
|
||||||
|
'''
|
||||||
|
保存单个文件。
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
filename = file.filename
|
||||||
|
file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
|
||||||
|
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
|
||||||
|
|
||||||
|
file_content = file.file.read() # 读取上传文件的内容
|
||||||
|
if (os.path.isfile(file_path)
|
||||||
|
and not override
|
||||||
|
and os.path.getsize(file_path) == len(file_content)
|
||||||
|
):
|
||||||
|
# TODO: filesize 不同后的处理
|
||||||
|
file_status = f"文件 {filename} 已存在。"
|
||||||
|
logger.warn(file_status)
|
||||||
|
return dict(code=404, msg=file_status, data=data)
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(file_content)
|
||||||
|
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"{filename} 文件上传失败,报错信息为: {e}"
|
||||||
|
logger.error(msg)
|
||||||
|
return dict(code=500, msg=msg, data=data)
|
||||||
|
|
||||||
|
params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
|
||||||
|
for result in run_in_thread_pool(save_file, params=params):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
# 似乎没有单独增加一个文件上传API接口的必要
|
||||||
|
# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||||
|
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||||
|
# override: bool = Form(False, description="覆盖已有文件")):
|
||||||
|
# '''
|
||||||
|
# API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
|
||||||
|
# '''
|
||||||
|
# def generate(files, knowledge_base_name, override):
|
||||||
|
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||||
|
# yield json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
|
# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: 等langchain.document_loaders支持内存文件的时候再开通
|
||||||
|
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||||
|
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||||
|
# override: bool = Form(False, description="覆盖已有文件"),
|
||||||
|
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
|
||||||
|
# def save_files(files, knowledge_base_name, override):
|
||||||
|
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||||
|
# yield json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
|
# def files_to_docs(files):
|
||||||
|
# for result in files2docs_in_thread(files):
|
||||||
|
# yield json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||||
|
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||||
override: bool = Form(False, description="覆盖已有文件"),
|
override: bool = Form(False, description="覆盖已有文件"),
|
||||||
|
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
|
||||||
|
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
|
'''
|
||||||
|
API接口:上传文件,并/或向量化
|
||||||
|
'''
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return BaseResponse(code=403, msg="Don't attack me")
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
@ -56,37 +133,36 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
file_content = await file.read() # 读取上传文件的内容
|
failed_files = {}
|
||||||
|
file_names = list(docs.keys())
|
||||||
|
|
||||||
try:
|
# 先将上传的文件保存到磁盘
|
||||||
kb_file = KnowledgeFile(filename=file.filename,
|
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||||
knowledge_base_name=knowledge_base_name)
|
filename = result["data"]["file_name"]
|
||||||
|
if result["code"] != 200:
|
||||||
|
failed_files[filename] = result["msg"]
|
||||||
|
|
||||||
if (os.path.exists(kb_file.filepath)
|
if filename not in file_names:
|
||||||
and not override
|
file_names.append(filename)
|
||||||
and os.path.getsize(kb_file.filepath) == len(file_content)
|
|
||||||
):
|
|
||||||
# TODO: filesize 不同后的处理
|
|
||||||
file_status = f"文件 {kb_file.filename} 已存在。"
|
|
||||||
return BaseResponse(code=404, msg=file_status)
|
|
||||||
|
|
||||||
with open(kb_file.filepath, "wb") as f:
|
# 对保存的文件进行向量化
|
||||||
f.write(file_content)
|
if to_vector_store:
|
||||||
except Exception as e:
|
result = await update_docs(
|
||||||
print(e)
|
knowledge_base_name=knowledge_base_name,
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
|
file_names=file_names,
|
||||||
|
override_custom_docs=True,
|
||||||
|
docs=docs,
|
||||||
|
not_refresh_vs_cache=True,
|
||||||
|
)
|
||||||
|
failed_files.update(result.data["failed_files"])
|
||||||
|
if not not_refresh_vs_cache:
|
||||||
|
kb.save_vector_store()
|
||||||
|
|
||||||
try:
|
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
|
||||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
|
|
||||||
|
|
||||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
doc_name: str = Body(..., examples=["file_name.md"]),
|
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
|
||||||
delete_content: bool = Body(False),
|
delete_content: bool = Body(False),
|
||||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
@ -98,23 +174,31 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
if not kb.exist_doc(doc_name):
|
failed_files = {}
|
||||||
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
|
for file_name in file_names:
|
||||||
|
if not kb.exist_doc(file_name):
|
||||||
|
failed_files[file_name] = f"未找到文件 {file_name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kb_file = KnowledgeFile(filename=doc_name,
|
kb_file = KnowledgeFile(filename=file_name,
|
||||||
knowledge_base_name=knowledge_base_name)
|
knowledge_base_name=knowledge_base_name)
|
||||||
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
|
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
msg = f"{file_name} 文件删除失败,错误信息:{e}"
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
|
logger.error(msg)
|
||||||
|
failed_files[file_name] = msg
|
||||||
|
|
||||||
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
|
if not not_refresh_vs_cache:
|
||||||
|
kb.save_vector_store()
|
||||||
|
|
||||||
|
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
|
||||||
|
|
||||||
|
|
||||||
async def update_doc(
|
async def update_docs(
|
||||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||||
file_name: str = Body(..., examples=["file_name"]),
|
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
|
||||||
|
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||||
|
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
'''
|
'''
|
||||||
@ -127,22 +211,57 @@ async def update_doc(
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
try:
|
failed_files = {}
|
||||||
kb_file = KnowledgeFile(filename=file_name,
|
kb_files = []
|
||||||
knowledge_base_name=knowledge_base_name)
|
|
||||||
if os.path.exists(kb_file.filepath):
|
|
||||||
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
|
||||||
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
|
|
||||||
|
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
|
# 生成需要加载docs的文件列表
|
||||||
|
for file_name in file_names:
|
||||||
|
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
|
||||||
|
# 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖
|
||||||
|
if file_detail.get("custom_docs") and not override_custom_docs:
|
||||||
|
continue
|
||||||
|
if file_name not in docs:
|
||||||
|
try:
|
||||||
|
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"加载文档 {file_name} 时出错:{e}"
|
||||||
|
logger.error(msg)
|
||||||
|
failed_files[file_name] = msg
|
||||||
|
|
||||||
|
# 从文件生成docs,并进行向量化。
|
||||||
|
# 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile
|
||||||
|
for status, result in files2docs_in_thread(kb_files):
|
||||||
|
if status:
|
||||||
|
kb_name, file_name, new_docs = result
|
||||||
|
kb_file = KnowledgeFile(filename=file_name,
|
||||||
|
knowledge_base_name=knowledge_base_name)
|
||||||
|
kb_file.splited_docs = new_docs
|
||||||
|
kb.update_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
|
else:
|
||||||
|
kb_name, file_name, error = result
|
||||||
|
failed_files[file_name] = error
|
||||||
|
|
||||||
|
# 将自定义的docs进行向量化
|
||||||
|
for file_name, v in docs.items():
|
||||||
|
try:
|
||||||
|
v = [x if isinstance(x, Document) else Document(**x) for x in v]
|
||||||
|
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
|
||||||
|
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"为 {file_name} 添加自定义docs时出错:{e}"
|
||||||
|
logger.error(msg)
|
||||||
|
failed_files[file_name] = msg
|
||||||
|
|
||||||
|
if not not_refresh_vs_cache:
|
||||||
|
kb.save_vector_store()
|
||||||
|
|
||||||
|
return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
|
||||||
|
|
||||||
|
|
||||||
async def download_doc(
|
async def download_doc(
|
||||||
knowledge_base_name: str = Query(..., examples=["samples"]),
|
knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
|
||||||
file_name: str = Query(..., examples=["test.txt"]),
|
file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
|
||||||
|
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
下载知识库文档
|
下载知识库文档
|
||||||
@ -154,6 +273,11 @@ async def download_doc(
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
if preview:
|
||||||
|
content_disposition_type = "inline"
|
||||||
|
else:
|
||||||
|
content_disposition_type = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kb_file = KnowledgeFile(filename=file_name,
|
kb_file = KnowledgeFile(filename=file_name,
|
||||||
knowledge_base_name=knowledge_base_name)
|
knowledge_base_name=knowledge_base_name)
|
||||||
@ -162,10 +286,13 @@ async def download_doc(
|
|||||||
return FileResponse(
|
return FileResponse(
|
||||||
path=kb_file.filepath,
|
path=kb_file.filepath,
|
||||||
filename=kb_file.filename,
|
filename=kb_file.filename,
|
||||||
media_type="multipart/form-data")
|
media_type="multipart/form-data",
|
||||||
|
content_disposition_type=content_disposition_type,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
|
logger.error(msg)
|
||||||
|
return BaseResponse(code=500, msg=msg)
|
||||||
|
|
||||||
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
|
||||||
|
|
||||||
@ -190,27 +317,30 @@ async def recreate_vector_store(
|
|||||||
else:
|
else:
|
||||||
kb.create_kb()
|
kb.create_kb()
|
||||||
kb.clear_vs()
|
kb.clear_vs()
|
||||||
docs = list_files_from_folder(knowledge_base_name)
|
files = list_files_from_folder(knowledge_base_name)
|
||||||
for i, doc in enumerate(docs):
|
kb_files = [(file, knowledge_base_name) for file in files]
|
||||||
try:
|
i = 0
|
||||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
for status, result in files2docs_in_thread(kb_files):
|
||||||
|
if status:
|
||||||
|
kb_name, file_name, docs = result
|
||||||
|
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||||
|
kb_file.splited_docs = docs
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 200,
|
"code": 200,
|
||||||
"msg": f"({i + 1} / {len(docs)}): {doc}",
|
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||||
"total": len(docs),
|
"total": len(files),
|
||||||
"finished": i,
|
"finished": i,
|
||||||
"doc": doc,
|
"doc": file_name,
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
if i == len(docs) - 1:
|
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
not_refresh_vs_cache = False
|
else:
|
||||||
else:
|
kb_name, file_name, error = result
|
||||||
not_refresh_vs_cache = True
|
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||||
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
|
logger.error(msg)
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 500,
|
"code": 500,
|
||||||
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
|
"msg": msg,
|
||||||
})
|
})
|
||||||
|
i += 1
|
||||||
|
|
||||||
return StreamingResponse(output(), media_type="text/event-stream")
|
return StreamingResponse(output(), media_type="text/event-stream")
|
||||||
|
|||||||
@ -49,6 +49,13 @@ class KBService(ABC):
|
|||||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||||
return load_embeddings(self.embed_model, embed_device)
|
return load_embeddings(self.embed_model, embed_device)
|
||||||
|
|
||||||
|
def save_vector_store(self, vector_store=None):
|
||||||
|
'''
|
||||||
|
保存向量库,仅支持FAISS。对于其它向量库该函数不做任何操作。
|
||||||
|
减少FAISS向量库操作时的类型判断。
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
def create_kb(self):
|
def create_kb(self):
|
||||||
"""
|
"""
|
||||||
创建知识库
|
创建知识库
|
||||||
@ -82,6 +89,8 @@ class KBService(ABC):
|
|||||||
"""
|
"""
|
||||||
if docs:
|
if docs:
|
||||||
custom_docs = True
|
custom_docs = True
|
||||||
|
for doc in docs:
|
||||||
|
doc.metadata.setdefault("source", kb_file.filepath)
|
||||||
else:
|
else:
|
||||||
docs = kb_file.file2text()
|
docs = kb_file.file2text()
|
||||||
custom_docs = False
|
custom_docs = False
|
||||||
|
|||||||
@ -5,7 +5,8 @@ from configs.model_config import (
|
|||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
CACHED_VS_NUM,
|
CACHED_VS_NUM,
|
||||||
EMBEDDING_MODEL,
|
EMBEDDING_MODEL,
|
||||||
SCORE_THRESHOLD
|
SCORE_THRESHOLD,
|
||||||
|
logger,
|
||||||
)
|
)
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -28,7 +29,7 @@ def load_faiss_vector_store(
|
|||||||
embeddings: Embeddings = None,
|
embeddings: Embeddings = None,
|
||||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||||
) -> FAISS:
|
) -> FAISS:
|
||||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
logger.info(f"loading vector store in '{knowledge_base_name}'.")
|
||||||
vs_path = get_vs_path(knowledge_base_name)
|
vs_path = get_vs_path(knowledge_base_name)
|
||||||
if embeddings is None:
|
if embeddings is None:
|
||||||
embeddings = load_embeddings(embed_model, embed_device)
|
embeddings = load_embeddings(embed_model, embed_device)
|
||||||
@ -57,7 +58,7 @@ def refresh_vs_cache(kb_name: str):
|
|||||||
make vector store cache refreshed when next loading
|
make vector store cache refreshed when next loading
|
||||||
"""
|
"""
|
||||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||||
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
||||||
|
|
||||||
|
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
@ -128,7 +129,7 @@ class FaissKBService(KBService):
|
|||||||
**kwargs):
|
**kwargs):
|
||||||
vector_store = self.load_vector_store()
|
vector_store = self.load_vector_store()
|
||||||
|
|
||||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
|
||||||
if len(ids) == 0:
|
if len(ids) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,8 @@ from configs.model_config import (
|
|||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
OVERLAP_SIZE,
|
OVERLAP_SIZE,
|
||||||
ZH_TITLE_ENHANCE
|
ZH_TITLE_ENHANCE,
|
||||||
|
logger,
|
||||||
)
|
)
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import importlib
|
import importlib
|
||||||
@ -19,6 +20,7 @@ from pathlib import Path
|
|||||||
import json
|
import json
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from server.utils import run_in_thread_pool
|
from server.utils import run_in_thread_pool
|
||||||
|
import io
|
||||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||||
|
|
||||||
|
|
||||||
@ -175,12 +177,74 @@ def get_LoaderClass(file_extension):
|
|||||||
return LoaderClass
|
return LoaderClass
|
||||||
|
|
||||||
|
|
||||||
|
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
|
||||||
|
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||||
|
'''
|
||||||
|
根据loader_name和文件路径或内容返回文档加载器。
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||||
|
document_loaders_module = importlib.import_module('document_loaders')
|
||||||
|
else:
|
||||||
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
|
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}")
|
||||||
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
|
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||||
|
|
||||||
|
if loader_name == "UnstructuredFileLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||||
|
elif loader_name == "CSVLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
|
||||||
|
elif loader_name == "JSONLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||||
|
elif loader_name == "CustomJSONLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, text_content=False)
|
||||||
|
elif loader_name == "UnstructuredMarkdownLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||||
|
elif loader_name == "UnstructuredHTMLLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||||
|
else:
|
||||||
|
loader = DocumentLoader(file_path_or_content)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_splitter(
|
||||||
|
splitter_name: str = "SpacyTextSplitter",
|
||||||
|
chunk_size: int = CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = OVERLAP_SIZE,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
根据参数获取特定的分词器
|
||||||
|
'''
|
||||||
|
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||||
|
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||||
|
try:
|
||||||
|
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查找分词器 {splitter_name} 时出错:{e}")
|
||||||
|
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
return text_splitter
|
||||||
|
|
||||||
class KnowledgeFile:
|
class KnowledgeFile:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str,
|
||||||
knowledge_base_name: str
|
knowledge_base_name: str
|
||||||
):
|
):
|
||||||
|
'''
|
||||||
|
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||||
|
'''
|
||||||
self.kb_name = knowledge_base_name
|
self.kb_name = knowledge_base_name
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.ext = os.path.splitext(filename)[-1].lower()
|
self.ext = os.path.splitext(filename)[-1].lower()
|
||||||
@ -196,65 +260,11 @@ class KnowledgeFile:
|
|||||||
|
|
||||||
def file2docs(self, refresh: bool=False):
|
def file2docs(self, refresh: bool=False):
|
||||||
if self.docs is None or refresh:
|
if self.docs is None or refresh:
|
||||||
print(f"{self.document_loader_name} used for {self.filepath}")
|
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||||
try:
|
loader = get_loader(self.document_loader_name, self.filepath)
|
||||||
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
|
||||||
document_loaders_module = importlib.import_module('document_loaders')
|
|
||||||
else:
|
|
||||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
|
||||||
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
|
||||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
|
||||||
if self.document_loader_name == "UnstructuredFileLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
|
||||||
elif self.document_loader_name == "CSVLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
|
||||||
elif self.document_loader_name == "JSONLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
|
||||||
elif self.document_loader_name == "CustomJSONLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, text_content=False)
|
|
||||||
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, mode="elements")
|
|
||||||
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
|
||||||
loader = DocumentLoader(self.filepath, mode="elements")
|
|
||||||
else:
|
|
||||||
loader = DocumentLoader(self.filepath)
|
|
||||||
self.docs = loader.load()
|
self.docs = loader.load()
|
||||||
return self.docs
|
return self.docs
|
||||||
|
|
||||||
def make_text_splitter(
|
|
||||||
self,
|
|
||||||
chunk_size: int = CHUNK_SIZE,
|
|
||||||
chunk_overlap: int = OVERLAP_SIZE,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
if self.text_splitter_name is None:
|
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
|
||||||
TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
|
|
||||||
text_splitter = TextSplitter(
|
|
||||||
pipeline="zh_core_web_sm",
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
)
|
|
||||||
self.text_splitter_name = "SpacyTextSplitter"
|
|
||||||
else:
|
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
|
||||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
|
||||||
text_splitter = TextSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
|
||||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
|
||||||
text_splitter = TextSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
)
|
|
||||||
return text_splitter
|
|
||||||
|
|
||||||
def docs2texts(
|
def docs2texts(
|
||||||
self,
|
self,
|
||||||
docs: List[Document] = None,
|
docs: List[Document] = None,
|
||||||
@ -265,10 +275,11 @@ class KnowledgeFile:
|
|||||||
text_splitter: TextSplitter = None,
|
text_splitter: TextSplitter = None,
|
||||||
):
|
):
|
||||||
docs = docs or self.file2docs(refresh=refresh)
|
docs = docs or self.file2docs(refresh=refresh)
|
||||||
|
if not docs:
|
||||||
|
return []
|
||||||
if self.ext not in [".csv"]:
|
if self.ext not in [".csv"]:
|
||||||
if text_splitter is None:
|
if text_splitter is None:
|
||||||
text_splitter = self.make_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
docs = text_splitter.split_documents(docs)
|
docs = text_splitter.split_documents(docs)
|
||||||
|
|
||||||
print(f"文档切分示例:{docs[0]}")
|
print(f"文档切分示例:{docs[0]}")
|
||||||
@ -286,13 +297,18 @@ class KnowledgeFile:
|
|||||||
text_splitter: TextSplitter = None,
|
text_splitter: TextSplitter = None,
|
||||||
):
|
):
|
||||||
if self.splited_docs is None or refresh:
|
if self.splited_docs is None or refresh:
|
||||||
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
|
docs = self.file2docs()
|
||||||
|
self.splited_docs = self.docs2texts(docs=docs,
|
||||||
|
using_zh_title_enhance=using_zh_title_enhance,
|
||||||
refresh=refresh,
|
refresh=refresh,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
text_splitter=text_splitter)
|
text_splitter=text_splitter)
|
||||||
return self.splited_docs
|
return self.splited_docs
|
||||||
|
|
||||||
|
def file_exist(self):
|
||||||
|
return os.path.isfile(self.filepath)
|
||||||
|
|
||||||
def get_mtime(self):
|
def get_mtime(self):
|
||||||
return os.path.getmtime(self.filepath)
|
return os.path.getmtime(self.filepath)
|
||||||
|
|
||||||
@ -301,18 +317,21 @@ class KnowledgeFile:
|
|||||||
|
|
||||||
|
|
||||||
def files2docs_in_thread(
|
def files2docs_in_thread(
|
||||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name)
|
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
||||||
pool: ThreadPoolExecutor = None,
|
pool: ThreadPoolExecutor = None,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
'''
|
'''
|
||||||
利用多线程批量将磁盘文件转化成langchain Document.
|
利用多线程批量将磁盘文件转化成langchain Document.
|
||||||
生成器返回值为{(kb_name, file_name): docs}
|
如果传入参数是Tuple,形式为(filename, kb_name)
|
||||||
|
生成器返回值为 status, (kb_name, file_name, docs | error)
|
||||||
'''
|
'''
|
||||||
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
||||||
try:
|
try:
|
||||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, e
|
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
|
||||||
|
logger.error(msg)
|
||||||
|
return False, (file.kb_name, file.filename, msg)
|
||||||
|
|
||||||
kwargs_list = []
|
kwargs_list = []
|
||||||
for i, file in enumerate(files):
|
for i, file in enumerate(files):
|
||||||
|
|||||||
431
server/knowledge_base/utils.py.bak
Normal file
431
server/knowledge_base/utils.py.bak
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
import os
|
||||||
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
|
from configs.model_config import (
|
||||||
|
embedding_model_dict,
|
||||||
|
KB_ROOT_PATH,
|
||||||
|
CHUNK_SIZE,
|
||||||
|
OVERLAP_SIZE,
|
||||||
|
ZH_TITLE_ENHANCE,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from functools import lru_cache
|
||||||
|
import importlib
|
||||||
|
from text_splitter import zh_title_enhance
|
||||||
|
import langchain.document_loaders
|
||||||
|
import document_loaders
|
||||||
|
import unstructured.partition
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.text_splitter import TextSplitter
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from server.utils import run_in_thread_pool
|
||||||
|
import io
|
||||||
|
import builtins
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||||
|
|
||||||
|
|
||||||
|
# make HuggingFaceEmbeddings hashable
|
||||||
|
def _embeddings_hash(self):
|
||||||
|
if isinstance(self, HuggingFaceEmbeddings):
|
||||||
|
return hash(self.model_name)
|
||||||
|
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||||
|
return hash(self.model_name)
|
||||||
|
elif isinstance(self, OpenAIEmbeddings):
|
||||||
|
return hash(self.model)
|
||||||
|
|
||||||
|
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||||
|
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||||
|
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||||
|
|
||||||
|
|
||||||
|
# patch langchain.document_loaders和项目自定义document_loaders,替换其中的open函数。
|
||||||
|
# 使其支持对str,bytes,io.StringIO,io.BytesIO进行向量化
|
||||||
|
def _new_open(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], *args, **kw):
|
||||||
|
if isinstance(content, (io.StringIO, io.BytesIO)):
|
||||||
|
return content
|
||||||
|
if isinstance(content, str):
|
||||||
|
if os.path.isfile(content):
|
||||||
|
return builtins.open(content, *args, **kw)
|
||||||
|
else:
|
||||||
|
return io.StringIO(content)
|
||||||
|
if isinstance(content, bytes):
|
||||||
|
return io.BytesIO(bytes)
|
||||||
|
if isinstance(content, Path):
|
||||||
|
return Path.open(*args, **kw)
|
||||||
|
return open(content, *args, **kw)
|
||||||
|
|
||||||
|
for module in [langchain.document_loaders, document_loaders]:
|
||||||
|
for k, v in module.__dict__.items():
|
||||||
|
if type(v) == type(langchain.document_loaders):
|
||||||
|
v.open = _new_open
|
||||||
|
|
||||||
|
# path unstructured 使其在处理非磁盘文件时不会出错
|
||||||
|
def _new_get_last_modified_date(filename: str) -> Union[str, None]:
|
||||||
|
try:
|
||||||
|
modify_date = datetime.fromtimestamp(os.path.getmtime(filename))
|
||||||
|
return modify_date.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for k, v in unstructured.partition.__dict__.items():
|
||||||
|
if type(v) == type(unstructured.partition):
|
||||||
|
v.open = _new_open
|
||||||
|
v.get_last_modified_date = _new_get_last_modified_date
|
||||||
|
|
||||||
|
|
||||||
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
# 检查是否包含预期外的字符或路径攻击关键字
|
||||||
|
if "../" in knowledge_base_id:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_kb_path(knowledge_base_name: str):
|
||||||
|
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_doc_path(knowledge_base_name: str):
|
||||||
|
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||||
|
|
||||||
|
|
||||||
|
def get_vs_path(knowledge_base_name: str):
|
||||||
|
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||||
|
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_kbs_from_folder():
|
||||||
|
return [f for f in os.listdir(KB_ROOT_PATH)
|
||||||
|
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
|
||||||
|
|
||||||
|
|
||||||
|
def list_files_from_folder(kb_name: str):
|
||||||
|
doc_path = get_doc_path(kb_name)
|
||||||
|
return [file for file in os.listdir(doc_path)
|
||||||
|
if os.path.isfile(os.path.join(doc_path, file))]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1)
|
||||||
|
def load_embeddings(model: str, device: str):
|
||||||
|
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||||
|
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
|
||||||
|
elif 'bge-' in model:
|
||||||
|
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
|
||||||
|
model_kwargs={'device': device},
|
||||||
|
query_instruction="为这个句子生成表示以用于检索相关文章:")
|
||||||
|
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
|
||||||
|
embeddings.query_instruction = ""
|
||||||
|
else:
|
||||||
|
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||||
|
"UnstructuredMarkdownLoader": ['.md'],
|
||||||
|
"CustomJSONLoader": [".json"],
|
||||||
|
"CSVLoader": [".csv"],
|
||||||
|
"RapidOCRPDFLoader": [".pdf"],
|
||||||
|
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
|
||||||
|
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||||
|
'.rtf', '.txt', '.xml',
|
||||||
|
'.doc', '.docx', '.epub', '.odt',
|
||||||
|
'.ppt', '.pptx', '.tsv'], # '.xlsx'
|
||||||
|
}
|
||||||
|
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
|
||||||
|
'''
|
||||||
|
langchain的JSONLoader需要jq,在win上使用不便,进行替代。
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
content_key: Optional[str] = None,
|
||||||
|
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||||
|
text_content: bool = True,
|
||||||
|
json_lines: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize the JSONLoader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
|
||||||
|
content_key (str): The key to use to extract the content from the JSON if
|
||||||
|
results to a list of objects (dict).
|
||||||
|
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
|
||||||
|
object extracted by the jq_schema and the default metadata and returns
|
||||||
|
a dict of the updated metadata.
|
||||||
|
text_content (bool): Boolean flag to indicate whether the content is in
|
||||||
|
string format, default to True.
|
||||||
|
json_lines (bool): Boolean flag to indicate whether the input is in
|
||||||
|
JSON Lines format.
|
||||||
|
"""
|
||||||
|
self.file_path = Path(file_path).resolve()
|
||||||
|
self._content_key = content_key
|
||||||
|
self._metadata_func = metadata_func
|
||||||
|
self._text_content = text_content
|
||||||
|
self._json_lines = json_lines
|
||||||
|
|
||||||
|
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
|
||||||
|
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
|
||||||
|
def load(self) -> List[Document]:
|
||||||
|
"""Load and return documents from the JSON file."""
|
||||||
|
docs: List[Document] = []
|
||||||
|
if self._json_lines:
|
||||||
|
with self.file_path.open(encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
self._parse(line, docs)
|
||||||
|
else:
|
||||||
|
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def _parse(self, content: str, docs: List[Document]) -> None:
|
||||||
|
"""Convert given content to documents."""
|
||||||
|
data = json.loads(content)
|
||||||
|
|
||||||
|
# Perform some validation
|
||||||
|
# This is not a perfect validation, but it should catch most cases
|
||||||
|
# and prevent the user from getting a cryptic error later on.
|
||||||
|
if self._content_key is not None:
|
||||||
|
self._validate_content_key(data)
|
||||||
|
|
||||||
|
for i, sample in enumerate(data, len(docs) + 1):
|
||||||
|
metadata = dict(
|
||||||
|
source=str(self.file_path),
|
||||||
|
seq_num=i,
|
||||||
|
)
|
||||||
|
text = self._get_text(sample=sample, metadata=metadata)
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
|
|
||||||
|
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
|
||||||
|
|
||||||
|
|
||||||
|
def get_LoaderClass(file_extension):
|
||||||
|
for LoaderClass, extensions in LOADER_DICT.items():
|
||||||
|
if file_extension in extensions:
|
||||||
|
return LoaderClass
|
||||||
|
|
||||||
|
|
||||||
|
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
|
||||||
|
'''
|
||||||
|
根据loader_name和文件路径或内容返回文档加载器。
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||||
|
document_loaders_module = importlib.import_module('document_loaders')
|
||||||
|
else:
|
||||||
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
|
DocumentLoader = getattr(document_loaders_module, loader_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||||
|
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||||
|
|
||||||
|
if loader_name == "UnstructuredFileLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
|
||||||
|
elif loader_name == "CSVLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
|
||||||
|
elif loader_name == "JSONLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
|
||||||
|
elif loader_name == "CustomJSONLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, text_content=False)
|
||||||
|
elif loader_name == "UnstructuredMarkdownLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||||
|
elif loader_name == "UnstructuredHTMLLoader":
|
||||||
|
loader = DocumentLoader(file_path_or_content, mode="elements")
|
||||||
|
else:
|
||||||
|
loader = DocumentLoader(file_path_or_content)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_splitter(
|
||||||
|
splitter_name: str = "SpacyTextSplitter",
|
||||||
|
chunk_size: int = CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = OVERLAP_SIZE,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
根据参数获取特定的分词器
|
||||||
|
'''
|
||||||
|
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||||
|
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||||
|
try:
|
||||||
|
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
return text_splitter
|
||||||
|
|
||||||
|
|
||||||
|
def content_to_docs(content: Union[str, bytes, io.StringIO, io.BytesIO, Path], ext: str = ".md") -> List[Document]:
|
||||||
|
'''
|
||||||
|
将磁盘文件、文本、字节、内存文件等转化成Document
|
||||||
|
'''
|
||||||
|
if not ext.startswith("."):
|
||||||
|
ext = "." + ext
|
||||||
|
ext = ext.lower()
|
||||||
|
if ext not in SUPPORTED_EXTS:
|
||||||
|
raise ValueError(f"暂未支持的文件格式 {ext}")
|
||||||
|
|
||||||
|
loader_name = get_LoaderClass(ext)
|
||||||
|
loader = get_loader(loader_name=loader_name, file_path_or_content=content)
|
||||||
|
return loader.load()
|
||||||
|
|
||||||
|
|
||||||
|
def split_docs(
|
||||||
|
docs: List[Document],
|
||||||
|
splitter_name: str = "SpacyTextSplitter",
|
||||||
|
chunk_size: int = CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = OVERLAP_SIZE,
|
||||||
|
) -> List[Document]:
|
||||||
|
text_splitter = make_text_splitter(splitter_name=splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
|
return text_splitter.split_documents(docs)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeFile:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
knowledge_base_name: str
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
|
||||||
|
'''
|
||||||
|
self.kb_name = knowledge_base_name
|
||||||
|
self.filename = filename
|
||||||
|
self.ext = os.path.splitext(filename)[-1].lower()
|
||||||
|
if self.ext not in SUPPORTED_EXTS:
|
||||||
|
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||||
|
self.filepath = get_file_path(knowledge_base_name, filename)
|
||||||
|
self.docs = None
|
||||||
|
self.splited_docs = None
|
||||||
|
self.document_loader_name = get_LoaderClass(self.ext)
|
||||||
|
|
||||||
|
# TODO: 增加依据文件格式匹配text_splitter
|
||||||
|
self.text_splitter_name = None
|
||||||
|
|
||||||
|
def file2docs(self, refresh: bool=False):
|
||||||
|
if self.docs is None or refresh:
|
||||||
|
logger.info(f"{self.document_loader_name} used for {self.filepath}")
|
||||||
|
loader = get_loader(self.document_loader_name, self.filepath)
|
||||||
|
self.docs = loader.load()
|
||||||
|
return self.docs
|
||||||
|
|
||||||
|
def docs2texts(
|
||||||
|
self,
|
||||||
|
docs: List[Document] = None,
|
||||||
|
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||||
|
refresh: bool = False,
|
||||||
|
chunk_size: int = CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = OVERLAP_SIZE,
|
||||||
|
text_splitter: TextSplitter = None,
|
||||||
|
):
|
||||||
|
docs = docs or self.file2docs(refresh=refresh)
|
||||||
|
|
||||||
|
if self.ext not in [".csv"]:
|
||||||
|
if text_splitter is None:
|
||||||
|
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||||
|
docs = text_splitter.split_documents(docs)
|
||||||
|
|
||||||
|
print(f"文档切分示例:{docs[0]}")
|
||||||
|
if using_zh_title_enhance:
|
||||||
|
docs = zh_title_enhance(docs)
|
||||||
|
self.splited_docs = docs
|
||||||
|
return self.splited_docs
|
||||||
|
|
||||||
|
def file2text(
|
||||||
|
self,
|
||||||
|
using_zh_title_enhance=ZH_TITLE_ENHANCE,
|
||||||
|
refresh: bool = False,
|
||||||
|
chunk_size: int = CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = OVERLAP_SIZE,
|
||||||
|
text_splitter: TextSplitter = None,
|
||||||
|
):
|
||||||
|
if self.splited_docs is None or refresh:
|
||||||
|
self.splited_docs = self.docs2texts(using_zh_title_enhance=using_zh_title_enhance,
|
||||||
|
refresh=refresh,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
text_splitter=text_splitter)
|
||||||
|
return self.splited_docs
|
||||||
|
|
||||||
|
def file_exist(self):
|
||||||
|
return os.path.isfile(self.filepath)
|
||||||
|
|
||||||
|
def get_mtime(self):
|
||||||
|
return os.path.getmtime(self.filepath)
|
||||||
|
|
||||||
|
def get_size(self):
|
||||||
|
return os.path.getsize(self.filepath)
|
||||||
|
|
||||||
|
|
||||||
|
def files2docs_in_thread(
|
||||||
|
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], # 如果是Tuple,形式为(filename, kb_name)
|
||||||
|
pool: ThreadPoolExecutor = None,
|
||||||
|
) -> Generator:
|
||||||
|
'''
|
||||||
|
利用多线程批量将磁盘文件转化成langchain Document.
|
||||||
|
生成器返回值为 status, (kb_name, file_name, docs)
|
||||||
|
'''
|
||||||
|
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
|
||||||
|
try:
|
||||||
|
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
return False, e
|
||||||
|
|
||||||
|
kwargs_list = []
|
||||||
|
for i, file in enumerate(files):
|
||||||
|
kwargs = {}
|
||||||
|
if isinstance(file, tuple) and len(file) >= 2:
|
||||||
|
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
|
||||||
|
elif isinstance(file, dict):
|
||||||
|
filename = file.pop("filename")
|
||||||
|
kb_name = file.pop("kb_name")
|
||||||
|
kwargs = file
|
||||||
|
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
|
||||||
|
kwargs["file"] = file
|
||||||
|
kwargs_list.append(kwargs)
|
||||||
|
|
||||||
|
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
# kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
|
||||||
|
# # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
|
||||||
|
# docs = kb_file.file2docs()
|
||||||
|
# pprint(docs[-1])
|
||||||
|
|
||||||
|
# docs = kb_file.file2text()
|
||||||
|
# pprint(docs[-1])
|
||||||
|
|
||||||
|
docs = content_to_docs("""
|
||||||
|
## this is a title
|
||||||
|
|
||||||
|
## another title
|
||||||
|
|
||||||
|
how are you
|
||||||
|
this a wonderful day.
|
||||||
|
""", "txt")
|
||||||
|
pprint(docs)
|
||||||
|
pprint(split_docs(docs, chunk_size=10, chunk_overlap=0))
|
||||||
@ -7,17 +7,21 @@ root_path = Path(__file__).parent.parent.parent
|
|||||||
sys.path.append(str(root_path))
|
sys.path.append(str(root_path))
|
||||||
from server.utils import api_address
|
from server.utils import api_address
|
||||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||||
from server.knowledge_base.utils import get_kb_path
|
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||||
|
from webui_pages.utils import ApiRequest
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
api_base_url = api_address()
|
api_base_url = api_address()
|
||||||
|
api = ApiRequest(api_base_url)
|
||||||
|
|
||||||
|
|
||||||
kb = "kb_for_api_test"
|
kb = "kb_for_api_test"
|
||||||
test_files = {
|
test_files = {
|
||||||
|
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||||
"README.MD": str(root_path / "README.MD"),
|
"README.MD": str(root_path / "README.MD"),
|
||||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
|
"test.txt": get_file_path("samples", "test.txt"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
|||||||
assert kb in data["data"]
|
assert kb in data["data"]
|
||||||
|
|
||||||
|
|
||||||
def test_upload_doc(api="/knowledge_base/upload_doc"):
|
def test_upload_docs(api="/knowledge_base/upload_docs"):
|
||||||
url = api_base_url + api
|
url = api_base_url + api
|
||||||
for name, path in test_files.items():
|
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||||
print(f"\n上传知识文件: {name}")
|
|
||||||
data = {"knowledge_base_name": kb, "override": True}
|
|
||||||
files = {"file": (name, open(path, "rb"))}
|
|
||||||
r = requests.post(url, data=data, files=files)
|
|
||||||
data = r.json()
|
|
||||||
pprint(data)
|
|
||||||
assert data["code"] == 200
|
|
||||||
assert data["msg"] == f"成功上传文件 {name}"
|
|
||||||
|
|
||||||
for name, path in test_files.items():
|
print(f"\n上传知识文件")
|
||||||
print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
|
data = {"knowledge_base_name": kb, "override": True}
|
||||||
data = {"knowledge_base_name": kb, "override": False}
|
r = requests.post(url, data=data, files=files)
|
||||||
files = {"file": (name, open(path, "rb"))}
|
data = r.json()
|
||||||
r = requests.post(url, data=data, files=files)
|
pprint(data)
|
||||||
data = r.json()
|
assert data["code"] == 200
|
||||||
pprint(data)
|
assert len(data["data"]["failed_files"]) == 0
|
||||||
assert data["code"] == 404
|
|
||||||
assert data["msg"] == f"文件 {name} 已存在。"
|
|
||||||
|
|
||||||
for name, path in test_files.items():
|
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||||
print(f"\n尝试重新上传知识文件: {name}, 覆盖")
|
data = {"knowledge_base_name": kb, "override": False}
|
||||||
data = {"knowledge_base_name": kb, "override": True}
|
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||||
files = {"file": (name, open(path, "rb"))}
|
r = requests.post(url, data=data, files=files)
|
||||||
r = requests.post(url, data=data, files=files)
|
data = r.json()
|
||||||
data = r.json()
|
pprint(data)
|
||||||
pprint(data)
|
assert data["code"] == 200
|
||||||
assert data["code"] == 200
|
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||||
assert data["msg"] == f"成功上传文件 {name}"
|
|
||||||
|
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||||
|
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||||
|
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||||
|
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
|
||||||
|
r = requests.post(url, data=data, files=files)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert len(data["data"]["failed_files"]) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_list_files(api="/knowledge_base/list_files"):
|
def test_list_files(api="/knowledge_base/list_files"):
|
||||||
@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
|
|||||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
|
||||||
def test_update_doc(api="/knowledge_base/update_doc"):
|
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||||
url = api_base_url + api
|
url = api_base_url + api
|
||||||
for name, path in test_files.items():
|
|
||||||
print(f"\n更新知识文件: {name}")
|
print(f"\n更新知识文件")
|
||||||
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
|
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||||
data = r.json()
|
data = r.json()
|
||||||
pprint(data)
|
pprint(data)
|
||||||
assert data["code"] == 200
|
assert data["code"] == 200
|
||||||
assert data["msg"] == f"成功更新文件 {name}"
|
assert len(data["data"]["failed_files"]) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_delete_doc(api="/knowledge_base/delete_doc"):
|
def test_delete_docs(api="/knowledge_base/delete_docs"):
|
||||||
url = api_base_url + api
|
url = api_base_url + api
|
||||||
for name, path in test_files.items():
|
|
||||||
print(f"\n删除知识文件: {name}")
|
print(f"\n删除知识文件")
|
||||||
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
|
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
|
||||||
data = r.json()
|
data = r.json()
|
||||||
pprint(data)
|
pprint(data)
|
||||||
assert data["code"] == 200
|
assert data["code"] == 200
|
||||||
assert data["msg"] == f"{name} 文件删除成功"
|
assert len(data["data"]["failed_files"]) == 0
|
||||||
|
|
||||||
url = api_base_url + "/knowledge_base/search_docs"
|
url = api_base_url + "/knowledge_base/search_docs"
|
||||||
query = "介绍一下langchain-chatchat项目"
|
query = "介绍一下langchain-chatchat项目"
|
||||||
|
|||||||
@ -21,9 +21,7 @@ from fastapi.responses import StreamingResponse
|
|||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
|
||||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
|
||||||
from server.utils import run_async, iter_over_async, set_httpx_timeout
|
|
||||||
|
|
||||||
from configs.model_config import NLTK_DATA_PATH
|
from configs.model_config import NLTK_DATA_PATH
|
||||||
import nltk
|
import nltk
|
||||||
@ -43,7 +41,7 @@ class ApiRequest:
|
|||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str = "http://127.0.0.1:7861",
|
base_url: str = api_address(),
|
||||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||||
no_remote_api: bool = False, # call api view function directly
|
no_remote_api: bool = False, # call api view function directly
|
||||||
):
|
):
|
||||||
@ -78,7 +76,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return httpx.get(url, params=params, **kwargs)
|
return httpx.get(url, params=params, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when get {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def aget(
|
async def aget(
|
||||||
@ -99,7 +97,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return await client.get(url, params=params, **kwargs)
|
return await client.get(url, params=params, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when aget {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def post(
|
def post(
|
||||||
@ -121,7 +119,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return httpx.post(url, data=data, json=json, **kwargs)
|
return httpx.post(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when post {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def apost(
|
async def apost(
|
||||||
@ -143,7 +141,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return await client.post(url, data=data, json=json, **kwargs)
|
return await client.post(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when apost {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
@ -164,7 +162,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return httpx.delete(url, data=data, json=json, **kwargs)
|
return httpx.delete(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when delete {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def adelete(
|
async def adelete(
|
||||||
@ -186,7 +184,7 @@ class ApiRequest:
|
|||||||
else:
|
else:
|
||||||
return await client.delete(url, data=data, json=json, **kwargs)
|
return await client.delete(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when adelete {url}: {e}")
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
||||||
@ -205,7 +203,7 @@ class ApiRequest:
|
|||||||
elif chunk.strip():
|
elif chunk.strip():
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(f"error when run fastapi router: {e}")
|
||||||
|
|
||||||
def _httpx_stream2generator(
|
def _httpx_stream2generator(
|
||||||
self,
|
self,
|
||||||
@ -231,18 +229,18 @@ class ApiRequest:
|
|||||||
print(chunk, end="", flush=True)
|
print(chunk, end="", flush=True)
|
||||||
yield chunk
|
yield chunk
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
|
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
||||||
|
logger.error(msg)
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
logger.error(e)
|
|
||||||
yield {"code": 500, "msg": msg}
|
yield {"code": 500, "msg": msg}
|
||||||
except httpx.ReadTimeout as e:
|
except httpx.ReadTimeout as e:
|
||||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')"
|
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')。({e})"
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
logger.error(e)
|
|
||||||
yield {"code": 500, "msg": msg}
|
yield {"code": 500, "msg": msg}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
msg = f"API通信遇到错误:{e}"
|
||||||
yield {"code": 500, "msg": str(e)}
|
logger.error(msg)
|
||||||
|
yield {"code": 500, "msg": msg}
|
||||||
|
|
||||||
# 对话相关操作
|
# 对话相关操作
|
||||||
|
|
||||||
@ -413,8 +411,9 @@ class ApiRequest:
|
|||||||
try:
|
try:
|
||||||
return response.json()
|
return response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
|
||||||
return {"code": 500, "msg": errorMsg or str(e)}
|
logger.error(msg)
|
||||||
|
return {"code": 500, "msg": msg}
|
||||||
|
|
||||||
def list_knowledge_bases(
|
def list_knowledge_bases(
|
||||||
self,
|
self,
|
||||||
@ -510,12 +509,13 @@ class ApiRequest:
|
|||||||
data = self._check_httpx_json_response(response)
|
data = self._check_httpx_json_response(response)
|
||||||
return data.get("data", [])
|
return data.get("data", [])
|
||||||
|
|
||||||
def upload_kb_doc(
|
def upload_kb_docs(
|
||||||
self,
|
self,
|
||||||
file: Union[str, Path, bytes],
|
files: List[Union[str, Path, bytes]],
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
filename: str = None,
|
|
||||||
override: bool = False,
|
override: bool = False,
|
||||||
|
to_vector_store: bool = True,
|
||||||
|
docs: List[Dict] = [],
|
||||||
not_refresh_vs_cache: bool = False,
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user