diff --git a/.gitignore b/.gitignore index b5918eeb..a7ef90f8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,7 @@ logs .idea/ __pycache__/ -knowledge_base/ +/knowledge_base/ configs/*.py +.vscode/ +.pytest_cache/ diff --git a/README.md b/README.md index c5c2d687..577e7f7b 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,8 @@ * [2. 下载模型至本地](README.md#2.-下载模型至本地) * [3. 设置配置项](README.md#3.-设置配置项) * [4. 知识库初始化与迁移](README.md#4.-知识库初始化与迁移) - * [5. 启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI) - * [6. 一键启动](README.md#6.-一键启动) + * [5. 一键启动API服务或WebUI服务](README.md#6.-一键启动) + * [6. 分步启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI) * [常见问题](README.md#常见问题) * [路线图](README.md#路线图) * [项目交流群](README.md#项目交流群) @@ -226,9 +226,93 @@ embedding_model_dict = { $ python init_database.py --recreate-vs ``` -### 5. 启动 API 服务或 Web UI +### 5. 一键启动API 服务或 Web UI -#### 5.1 启动 LLM 服务 +#### 5.1 启动命令 + +一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码: + +```shell +$ python startup.py -a +``` + +并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。 + +可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`, +`-m (或--model-worker)`, `--api`, `--webui`,其中: + +- `--all-webui` 为一键启动 WebUI 所有依赖服务; +- `--all-api` 为一键启动 API 所有依赖服务; +- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务; +- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务; +- 其他为单独服务启动选项。 + +#### 5.2 启动非默认模型 + +若想指定非默认模型,需要用 `--model-name` 选项,示例: + +```shell +$ python startup.py --all-webui --model-name Qwen-7B-Chat +``` + +更多信息可通过 `python startup.py -h`查看。 + +#### 5.3 多卡加载 + +项目支持多卡加载,需在 startup.py 中的 create_model_worker_app 函数中,修改如下三个参数: + +```python +gpus=None, +num_gpus=1, +max_gpu_memory="20GiB" +``` + +其中,`gpus` 控制使用的显卡的ID,例如 "0,1"; + +`num_gpus` 控制使用的卡数; + +`max_gpu_memory` 控制每个卡使用的显存容量。 + +注1:server_config.py的FSCHAT_MODEL_WORKERS字典中也增加了相关配置,如有需要也可通过修改FSCHAT_MODEL_WORKERS字典中对应参数实现多卡加载。 + +注2:少数情况下,gpus参数会不生效,此时需要通过设置环境变量CUDA_VISIBLE_DEVICES来指定torch可见的gpu,示例代码: + +```shell +CUDA_VISIBLE_DEVICES=0,1 python startup.py -a +``` + +#### 5.4 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等) + +本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含.bin 格式的 PEFT 权重,peft路径在startup.py中create_model_worker_app函数的args.model_names中指定,并开启环境变量PEFT_SHARE_BASE_WEIGHTS=true参数。 + +注:如果上述方式启动失败,则需要以标准的fastchat服务启动方式分步启动,分步启动步骤参考第六节,PEFT加载详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822), + +#### **5.5 注意事项:** + +**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)** + +**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。** + +**3. 在Linux上使用ctrl+C退出可能会由于linux的多进程机制导致multiprocessing遗留孤儿进程,可通过shutdown_all.sh进行退出** + +#### 5.6 启动界面示例: + +1. FastAPI docs 界面 + +![](img/fastapi_docs_020_0.png) + +2. 
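A supplement to note 1 of section 5.3: multi-GPU loading can also be configured through the `FSCHAT_MODEL_WORKERS` dict in `configs/server_config.py` instead of editing `startup.py`. A minimal sketch — the three keys are the ones shipped (commented out) in `server_config.py.example`; the values here are illustrative:

```python
# configs/server_config.py — excerpt of FSCHAT_MODEL_WORKERS (keep the other keys unchanged)
FSCHAT_MODEL_WORKERS = {
    "default": {
        # ... existing host / port / device settings ...
        "gpus": "0,1",              # GPU ids to use, as a comma-separated string
        "num_gpus": 2,              # number of GPUs
        "max_gpu_memory": "20GiB",  # maximum VRAM to use per GPU
    },
}
```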
webui启动界面示例: + +- Web UI 对话界面: + ![img](img/webui_0813_0.png) +- Web UI 知识库管理页面: + ![](img/webui_0813_1.png) + +### 6 分步启动 API 服务或 Web UI + +注意:如使用了一键启动方式,可忽略本节。 + +#### 6.1 启动 LLM 服务 如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种: @@ -240,7 +324,7 @@ embedding_model_dict = { 如果启动在线的API服务(如 OPENAI 的 API 接口),则无需启动 LLM 服务,即 5.1 小节的任何命令均无需启动。 -##### 5.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务 +##### 6.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务 在项目根目录下,执行 [server/llm_api.py](server/llm_api.py) 脚本启动 **LLM 模型**服务: @@ -248,7 +332,7 @@ embedding_model_dict = { $ python server/llm_api.py ``` -项目支持多卡加载,需在 llm_api.py 中修改 create_model_worker_app 函数中,修改如下三个参数: +项目支持多卡加载,需在 llm_api.py 中的 create_model_worker_app 函数中,修改如下三个参数: ```python gpus=None, @@ -262,11 +346,11 @@ max_gpu_memory="20GiB" `max_gpu_memory` 控制每个卡使用的显存容量。 -##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务 +##### 6.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务 ⚠️ **注意:** -**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;** +**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;** **2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;** @@ -302,14 +386,14 @@ $ python server/llm_api_shutdown.py --serve all 亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`] -##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia等) +##### 6.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等) 本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。 详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822) ![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d) -#### 5.2 启动 API 服务 +#### 6.2 启动 API 服务 本地部署情况下,按照 [5.1 节](README.md#5.1-启动-LLM-服务)**启动 LLM 服务后**,再执行 [server/api.py](server/api.py) 脚本启动 **API** 服务; @@ -327,7 +411,7 @@ $ python server/api.py ![](img/fastapi_docs_020_0.png) -#### 5.3 启动 Web UI 服务 +#### 6.3 启动 Web UI 服务 按照 [5.2 节](README.md#5.2-启动-API-服务)**启动 API 服务后**,执行 [webui.py](webui.py) 启动 **Web UI** 服务(默认使用端口 `8501`) @@ -356,41 +440,6 @@ $ streamlit run webui.py --server.port 666 --- -### 6. 一键启动 - -更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码: - -```shell -$ python startup.py -a -``` - -并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。 - -可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`, -`-m (或--model-worker)`, `--api`, `--webui`,其中: - -- `--all-webui` 为一键启动 WebUI 所有依赖服务; -- `--all-api` 为一键启动 API 所有依赖服务; -- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务; -- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务; -- 其他为单独服务启动选项。 - -若想指定非默认模型,需要用 `--model-name` 选项,示例: - -```shell -$ python startup.py --all-webui --model-name Qwen-7B-Chat -``` - -更多信息可通过 `python startup.py -h`查看。 - -**注意:** - -**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)** - -**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。** - -**3. 
在Linux上使用ctrl+C退出可能会由于linux的多进程机制导致multiprocessing遗留孤儿进程,可通过shutdown_all.sh进行退出** - ## 常见问题 参见 [常见问题](docs/FAQ.md)。 @@ -433,6 +482,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat ## 项目交流群 -二维码 +二维码 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 5b2574e9..1acaf1c1 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,6 +1,5 @@ import os import logging -import torch # 日志格式 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() @@ -32,9 +31,8 @@ embedding_model_dict = { # 选用的 Embedding 名称 EMBEDDING_MODEL = "m3e-base" -# Embedding 模型运行设备 -EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - +# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 +EMBEDDING_DEVICE = "auto" llm_model_dict = { "chatglm-6b": { @@ -69,20 +67,32 @@ llm_model_dict = { # 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.._completion_with_retry in # 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI. # 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置 + # 比如: "openai_proxy": 'http://127.0.0.1:4780' "gpt-3.5-turbo": { - "local_model_path": "gpt-3.5-turbo", "api_base_url": "https://api.openai.com/v1", "api_key": os.environ.get("OPENAI_API_KEY"), "openai_proxy": os.environ.get("OPENAI_PROXY") }, + # 线上模型。当前支持智谱AI。 + # 如果没有设置有效的local_model_path,则认为是在线模型API。 + # 请在server_config中为每个在线API设置不同的端口 + # 具体注册及api key获取请前往 http://open.bigmodel.cn + "chatglm-api": { + "api_base_url": "http://127.0.0.1:8888/v1", + "api_key": os.environ.get("ZHIPUAI_API_KEY"), + "provider": "ChatGLMWorker", + "version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro" + }, } - # LLM 名称 LLM_MODEL = "chatglm2-6b" -# LLM 运行设备 -LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +# 历史对话轮数 +HISTORY_LEN = 3 + +# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 +LLM_DEVICE = "auto" # 日志存储路径 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") @@ -164,4 +174,4 @@ BING_SUBSCRIPTION_KEY = "" # 是否开启中文标题加强,以及标题增强的相关配置 # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 -ZH_TITLE_ENHANCE = False \ No newline at end of file +ZH_TITLE_ENHANCE = False diff --git a/configs/server_config.py.example b/configs/server_config.py.example index b0f37bf4..ad731e37 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -1,4 +1,8 @@ -from .model_config import LLM_MODEL, LLM_DEVICE +from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE +import httpx + +# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 +HTTPX_DEFAULT_TIMEOUT = 300.0 # API 是否开启跨域,默认为False,如果需要开启,请设置为True # is open cross domain @@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = { # 这些模型必须是在model_config.llm_model_dict中正确配置的。 # 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL FSCHAT_MODEL_WORKERS = { - LLM_MODEL: { + # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。 + "default": { "host": DEFAULT_BIND_HOST, "port": 20002, "device": LLM_DEVICE, - # todo: 多卡加载需要配置的参数 - "gpus": None, # 使用的GPU,以str的格式指定,如"0,1" - "num_gpus": 1, # 使用GPU的数量 - # 以下为非常用参数,可根据需要配置 + + # 多卡加载需要配置的参数 + # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1" + # "num_gpus": 1, # 使用GPU的数量 # "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存 + + # 以下为非常用参数,可根据需要配置 # 
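For reference, the `"auto"` setting for `EMBEDDING_DEVICE` / `LLM_DEVICE` stands for the detection logic that used to be inlined here and now lives in `server/utils.py` (e.g. `embedding_device()`). A minimal sketch of that logic, assuming `torch` is installed; the function name is illustrative:

```python
import torch

def detect_device() -> str:
    """Return 'cuda', 'mps' or 'cpu' — the behaviour implied by the 'auto' setting."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(detect_device())
```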
"load_8bit": False, # 开启8bit量化 # "cpu_offloading": None, # "gptq_ckpt": None, @@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = { # "stream_interval": 2, # "no_register": False, }, + "baichuan-7b": { # 使用default中的IP和端口 + "device": "cpu", + }, + "chatglm-api": { # 请为每个在线API设置不同的端口 + "port": 20003, + }, } # fastchat multi model worker server FSCHAT_MULTI_MODEL_WORKERS = { - # todo + # TODO: } # fastchat controller server @@ -66,35 +79,3 @@ FSCHAT_CONTROLLER = { "port": 20001, "dispatch_method": "shortest_queue", } - - -# 以下不要更改 -def fschat_controller_address() -> str: - host = FSCHAT_CONTROLLER["host"] - port = FSCHAT_CONTROLLER["port"] - return f"http://{host}:{port}" - - -def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str: - if model := FSCHAT_MODEL_WORKERS.get(model_name): - host = model["host"] - port = model["port"] - return f"http://{host}:{port}" - - -def fschat_openai_api_address() -> str: - host = FSCHAT_OPENAI_API["host"] - port = FSCHAT_OPENAI_API["port"] - return f"http://{host}:{port}" - - -def api_address() -> str: - host = API_SERVER["host"] - port = API_SERVER["port"] - return f"http://{host}:{port}" - - -def webui_address() -> str: - host = WEBUI_SERVER["host"] - port = WEBUI_SERVER["port"] - return f"http://{host}:{port}" diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py new file mode 100644 index 00000000..a4d6b28d --- /dev/null +++ b/document_loaders/__init__.py @@ -0,0 +1,2 @@ +from .mypdfloader import RapidOCRPDFLoader +from .myimgloader import RapidOCRLoader \ No newline at end of file diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py new file mode 100644 index 00000000..86481924 --- /dev/null +++ b/document_loaders/myimgloader.py @@ -0,0 +1,25 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def img2text(filepath): + from rapidocr_onnxruntime import RapidOCR + resp = "" + ocr = RapidOCR() + result, _ = ocr(filepath) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = img2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg") + docs = loader.load() + print(docs) diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py new file mode 100644 index 00000000..71e063d6 --- /dev/null +++ b/document_loaders/mypdfloader.py @@ -0,0 +1,37 @@ +from typing import List +from langchain.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRPDFLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def pdf2text(filepath): + import fitz + from rapidocr_onnxruntime import RapidOCR + import numpy as np + ocr = RapidOCR() + doc = fitz.open(filepath) + resp = "" + for page in doc: + # TODO: 依据文本与图片顺序调整处理方式 + text = page.get_text("") + resp += text + "\n" + + img_list = page.get_images() + for img in img_list: + pix = fitz.Pixmap(doc, img[0]) + img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1) + result, _ = ocr(img_array) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = pdf2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, 
**self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf") + docs = loader.load() + print(docs) diff --git a/img/qr_code_50.jpg b/img/qr_code_50.jpg deleted file mode 100644 index c0ae20f6..00000000 Binary files a/img/qr_code_50.jpg and /dev/null differ diff --git a/img/qr_code_51.jpg b/img/qr_code_51.jpg deleted file mode 100644 index f0993322..00000000 Binary files a/img/qr_code_51.jpg and /dev/null differ diff --git a/img/qr_code_52.jpg b/img/qr_code_52.jpg deleted file mode 100644 index 18793d56..00000000 Binary files a/img/qr_code_52.jpg and /dev/null differ diff --git a/img/qr_code_53.jpg b/img/qr_code_53.jpg deleted file mode 100644 index 3174ccc1..00000000 Binary files a/img/qr_code_53.jpg and /dev/null differ diff --git a/img/qr_code_54.jpg b/img/qr_code_54.jpg deleted file mode 100644 index 1245a164..00000000 Binary files a/img/qr_code_54.jpg and /dev/null differ diff --git a/img/qr_code_55.jpg b/img/qr_code_55.jpg deleted file mode 100644 index 8ff046c9..00000000 Binary files a/img/qr_code_55.jpg and /dev/null differ diff --git a/img/qr_code_56.jpg b/img/qr_code_56.jpg deleted file mode 100644 index f17458d2..00000000 Binary files a/img/qr_code_56.jpg and /dev/null differ diff --git a/img/qr_code_58.jpg b/img/qr_code_58.jpg new file mode 100644 index 00000000..90c861d7 Binary files /dev/null and b/img/qr_code_58.jpg differ diff --git a/init_database.py b/init_database.py index 7fc84940..42a18c53 100644 --- a/init_database.py +++ b/init_database.py @@ -1,8 +1,9 @@ -from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder +from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path from startup import dump_server_info +from datetime import datetime if __name__ == "__main__": @@ -25,13 +26,19 @@ if __name__ == "__main__": dump_server_info() - create_tables() - print("database talbes created") + start_time = datetime.now() if args.recreate_vs: + reset_tables() + print("database talbes reseted") print("recreating all vector stores") recreate_all_vs() else: + create_tables() + print("database talbes created") print("filling kb infos to database") for kb in list_kbs_from_folder(): folder2db(kb, "fill_info_only") + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") diff --git a/requirements.txt b/requirements.txt index e40f665a..4271f3af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,8 @@ SQLAlchemy==2.0.19 faiss-cpu accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/requirements_api.txt b/requirements_api.txt index 58dbc0cd..bdecf3c7 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -16,6 +16,8 @@ faiss-cpu nltk accelerate spacy +PyMuPDF==1.22.5 +rapidocr_onnxruntime>=1.3.1 # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 diff --git a/server/api.py b/server/api.py index ecadd7cc..37954b7f 100644 --- a/server/api.py +++ b/server/api.py @@ -4,20 +4,22 @@ import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import NLTK_DATA_PATH -from configs.server_config import OPEN_CROSS_DOMAIN +from configs.model_config import LLM_MODEL, 
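The two OCR loaders added above (`RapidOCRLoader`, `RapidOCRPDFLoader`) behave like any other langchain loader once `rapidocr_onnxruntime` and `PyMuPDF` from the updated requirements are installed. A short usage sketch — the sample paths mirror the ones in the modules' `__main__` blocks, taken relative to the repo root:

```python
from document_loaders import RapidOCRLoader, RapidOCRPDFLoader

# OCR an image into langchain Documents
img_docs = RapidOCRLoader(file_path="tests/samples/ocr_test.jpg").load()

# extract PDF text and additionally OCR the images embedded in each page
pdf_docs = RapidOCRPDFLoader(file_path="tests/samples/ocr_test.pdf").load()

print(len(img_docs), len(pdf_docs))
```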
NLTK_DATA_PATH +from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT from configs import VERSION import argparse import uvicorn +from fastapi import Body from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb -from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, +from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store, search_docs, DocumentWithScore) -from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline +from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address +import httpx from typing import List nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -84,11 +86,11 @@ def create_app(): summary="删除知识库" )(delete_kb) - app.get("/knowledge_base/list_docs", + app.get("/knowledge_base/list_files", tags=["Knowledge Base Management"], response_model=ListResponse, summary="获取知识库内的文件列表" - )(list_docs) + )(list_files) app.post("/knowledge_base/search_docs", tags=["Knowledge Base Management"], @@ -123,6 +125,75 @@ def create_app(): summary="根据content中文档重建向量库,流式输出处理进度。" )(recreate_vector_store) + # LLM模型相关接口 + @app.post("/llm_model/list_models", + tags=["LLM Model Management"], + summary="列出当前已加载的模型") + def list_models( + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) + ) -> BaseResponse: + ''' + 从fastchat controller获取已加载模型列表 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post(controller_address + "/list_models") + return BaseResponse(data=r.json()["models"]) + except Exception as e: + return BaseResponse( + code=500, + data=[], + msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") + + @app.post("/llm_model/stop", + tags=["LLM Model Management"], + summary="停止指定的LLM模型(Model Worker)", + ) + def stop_llm_model( + model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) + ) -> BaseResponse: + ''' + 向fastchat controller请求停止某个LLM模型。 + 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name}, + ) + return r.json() + except Exception as e: + return BaseResponse( + code=500, + msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") + + @app.post("/llm_model/change", + tags=["LLM Model Management"], + summary="切换指定的LLM模型(Model Worker)", + ) + def change_llm_model( + model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), + new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) + ): + ''' + 向fastchat controller请求切换LLM模型。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name, "new_model_name": new_model_name}, + 
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model + ) + return r.json() + except Exception as e: + return BaseResponse( + code=500, + msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") + return app diff --git a/server/chat/chat.py b/server/chat/chat.py index 2e939f1d..ba23a5a1 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 {"role": "assistant", "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), ): history = [History.from_data(h) for h in history] async def chat_iterator(query: str, history: List[History] = [], + model_name: str = LLM_MODEL, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() @@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 streaming=True, verbose=True, callbacks=[callback], - openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL, - openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") + openai_api_key=llm_model_dict[model_name]["api_key"], + openai_api_base=llm_model_dict[model_name]["api_base_url"], + model_name=model_name, + openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) input_msg = History(role="user", content="{{ input }}").to_msg_template(False) @@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 await task - return StreamingResponse(chat_iterator(query, history), + return StreamingResponse(chat_iterator(query, history, model_name), media_type="text/event-stream") diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 27745691..69ec25dd 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), request: Request = None, ): @@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp kb: KBService, top_k: int, history: Optional[List[History]], + model_name: str = LLM_MODEL, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() model = ChatOpenAI( streaming=True, verbose=True, callbacks=[callback], - openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL, - openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") + openai_api_key=llm_model_dict[model_name]["api_key"], + openai_api_base=llm_model_dict[model_name]["api_base_url"], + model_name=model_name, + openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) @@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp await task - return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history), + return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name), media_type="text/event-stream") diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index a7ad8074..a799c623 100644 --- 
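The new model-management endpoints above can be exercised with plain `httpx` against the API server (default `127.0.0.1:7861` per the README). A minimal client sketch, with the JSON field names taken from the `Body(...)` parameters defined above and the model names purely illustrative:

```python
import httpx

API_BASE = "http://127.0.0.1:7861"  # default API server address

# list the models currently loaded by the fastchat controller
print(httpx.post(f"{API_BASE}/llm_model/list_models").json())

# switch the running model; this waits until the new model_worker is up
r = httpx.post(
    f"{API_BASE}/llm_model/change",
    json={"model_name": "chatglm2-6b", "new_model_name": "Qwen-7B-Chat"},
    timeout=300,
)
print(r.json())
```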
a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(f"{openai.api_base=}") print(msg) - async def get_response(msg): + def get_response(msg): data = msg.dict() - data["streaming"] = True - data.pop("stream") try: response = openai.ChatCompletion.create(**data) if msg.stream: - for chunk in response.choices[0].message.content: - print(chunk) - yield chunk + for data in response: + if choices := data.choices: + if chunk := choices[0].get("delta", {}).get("content"): + print(chunk, end="", flush=True) + yield chunk else: - answer = "" - for chunk in response.choices[0].message.content: - answer += chunk - print(answer) - yield(answer) + if response.choices: + answer = response.choices[0].message.content + print(answer) + yield(answer) except Exception as e: print(type(e)) logger.error(e) diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8a2633bd..8fe7dae6 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl search_engine_name: str, top_k: int, history: Optional[List[History]], + model_name: str = LLM_MODEL, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() model = ChatOpenAI( streaming=True, verbose=True, callbacks=[callback], - openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL, - openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") + openai_api_key=llm_model_dict[model_name]["api_key"], + openai_api_base=llm_model_dict[model_name]["api_base_url"], + model_name=model_name, + openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) docs = lookup_search_engine(query, search_engine_name, top_k) @@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl ensure_ascii=False) await task - return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history), + return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name), media_type="text/event-stream") diff --git a/server/db/base.py b/server/db/base.py index 3a8529b0..1d911c05 100644 --- a/server/db/base.py +++ b/server/db/base.py @@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from configs.model_config import SQLALCHEMY_DATABASE_URI +import json -engine = create_engine(SQLALCHEMY_DATABASE_URI) + +engine = create_engine( + SQLALCHEMY_DATABASE_URI, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), +) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() - diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py index 37abd4ee..478bc1f3 100644 --- a/server/db/models/knowledge_base_model.py +++ b/server/db/models/knowledge_base_model.py @@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base): """ __tablename__ = 'knowledge_base' id = Column(Integer, primary_key=True, autoincrement=True, 
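With the `model_name` field added to the chat endpoints above, a request can pick any key of `llm_model_dict` per call. A minimal streaming client sketch, assuming the `/chat/chat` route registered elsewhere in `server/api.py` and the default API port; the example query and history are the ones used in the endpoint's own `Body(...)` examples:

```python
import httpx

payload = {
    "query": "恼羞成怒",
    "history": [
        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
        {"role": "assistant", "content": "虎头虎脑"},
    ],
    "stream": True,
    "model_name": "chatglm2-6b",  # any key of llm_model_dict
}

# the endpoint answers with text/event-stream; read it chunk by chunk
with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat", json=payload, timeout=300) as r:
    for chunk in r.iter_text():
        print(chunk, end="", flush=True)
```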
comment='知识库ID') - kb_name = Column(String, comment='知识库名称') - vs_type = Column(String, comment='嵌入模型类型') - embed_model = Column(String, comment='嵌入模型名称') + kb_name = Column(String(50), comment='知识库名称') + vs_type = Column(String(50), comment='向量库类型') + embed_model = Column(String(50), comment='嵌入模型名称') file_count = Column(Integer, default=0, comment='文件数量') create_time = Column(DateTime, default=func.now(), comment='创建时间') diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py index 7fffdfbd..c5784d1a 100644 --- a/server/db/models/knowledge_file_model.py +++ b/server/db/models/knowledge_file_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, func +from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func from server.db.base import Base @@ -9,13 +9,32 @@ class KnowledgeFileModel(Base): """ __tablename__ = 'knowledge_file' id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID') - file_name = Column(String, comment='文件名') - file_ext = Column(String, comment='文件扩展名') - kb_name = Column(String, comment='所属知识库名称') - document_loader_name = Column(String, comment='文档加载器名称') - text_splitter_name = Column(String, comment='文本分割器名称') + file_name = Column(String(255), comment='文件名') + file_ext = Column(String(10), comment='文件扩展名') + kb_name = Column(String(50), comment='所属知识库名称') + document_loader_name = Column(String(50), comment='文档加载器名称') + text_splitter_name = Column(String(50), comment='文本分割器名称') file_version = Column(Integer, default=1, comment='文件版本') + file_mtime = Column(Float, default=0.0, comment="文件修改时间") + file_size = Column(Integer, default=0, comment="文件大小") + custom_docs = Column(Boolean, default=False, comment="是否自定义docs") + docs_count = Column(Integer, default=0, comment="切分文档数量") create_time = Column(DateTime, default=func.now(), comment='创建时间') def __repr__(self): return f"" + + +class FileDocModel(Base): + """ + 文件-向量库文档模型 + """ + __tablename__ = 'file_doc' + id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') + kb_name = Column(String(50), comment='知识库名称') + file_name = Column(String(255), comment='文件名称') + doc_id = Column(String(50), comment="向量库文档ID") + meta_data = Column(JSON, default={}) + + def __repr__(self): + return f"" diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 404910ff..08417a4d 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -1,24 +1,101 @@ from server.db.models.knowledge_base_model import KnowledgeBaseModel -from server.db.models.knowledge_file_model import KnowledgeFileModel +from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel from server.db.session import with_session from server.knowledge_base.utils import KnowledgeFile +from typing import List, Dict @with_session -def list_docs_from_db(session, kb_name): +def list_docs_from_db(session, + kb_name: str, + file_name: str = None, + metadata: Dict = {}, + ) -> List[Dict]: + ''' + 列出某知识库某文件对应的所有Document。 + 返回形式:[{"id": str, "metadata": dict}, ...] 
+ ''' + docs = session.query(FileDocModel).filter_by(kb_name=kb_name) + if file_name: + docs = docs.filter_by(file_name=file_name) + for k, v in metadata.items(): + docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v)) + + return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()] + + +@with_session +def delete_docs_from_db(session, + kb_name: str, + file_name: str = None, + ) -> List[Dict]: + ''' + 删除某知识库某文件对应的所有Document,并返回被删除的Document。 + 返回形式:[{"id": str, "metadata": dict}, ...] + ''' + docs = list_docs_from_db(kb_name=kb_name, file_name=file_name) + query = session.query(FileDocModel).filter_by(kb_name=kb_name) + if file_name: + query = query.filter_by(file_name=file_name) + query.delete() + session.commit() + return docs + + +@with_session +def add_docs_to_db(session, + kb_name: str, + file_name: str, + doc_infos: List[Dict]): + ''' + 将某知识库某文件对应的所有Document信息添加到数据库。 + doc_infos形式:[{"id": str, "metadata": dict}, ...] + ''' + for d in doc_infos: + obj = FileDocModel( + kb_name=kb_name, + file_name=file_name, + doc_id=d["id"], + meta_data=d["metadata"], + ) + session.add(obj) + return True + + +@with_session +def count_files_from_db(session, kb_name: str) -> int: + return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count() + + +@with_session +def list_files_from_db(session, kb_name): files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all() docs = [f.file_name for f in files] return docs @with_session -def add_doc_to_db(session, kb_file: KnowledgeFile): +def add_file_to_db(session, + kb_file: KnowledgeFile, + docs_count: int = 0, + custom_docs: bool = False, + doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...] + ): kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() if kb: - # 如果已经存在该文件,则更新文件版本号 - existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, - kb_name=kb_file.kb_name).first() + # 如果已经存在该文件,则更新文件信息与版本号 + existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel) + .filter_by(file_name=kb_file.filename, + kb_name=kb_file.kb_name) + .first()) + mtime = kb_file.get_mtime() + size = kb_file.get_size() + if existing_file: + existing_file.file_mtime = mtime + existing_file.file_size = size + existing_file.docs_count = docs_count + existing_file.custom_docs = custom_docs existing_file.file_version += 1 # 否则,添加新文件 else: @@ -28,9 +105,14 @@ def add_doc_to_db(session, kb_file: KnowledgeFile): kb_name=kb_file.kb_name, document_loader_name=kb_file.document_loader_name, text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter", + file_mtime=mtime, + file_size=size, + docs_count = docs_count, + custom_docs=custom_docs, ) kb.file_count += 1 session.add(new_file) + add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos) return True @@ -40,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): kb_name=kb_file.kb_name).first() if existing_file: session.delete(existing_file) + delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename) session.commit() kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() @@ -52,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): @with_session def delete_files_from_db(session, knowledge_base_name: str): session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete() - + session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete() kb = 
session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first() if kb: kb.file_count = 0 @@ -62,7 +145,7 @@ def delete_files_from_db(session, knowledge_base_name: str): @with_session -def doc_exists(session, kb_file: KnowledgeFile): +def file_exists_in_db(session, kb_file: KnowledgeFile): existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename, kb_name=kb_file.kb_name).first() return True if existing_file else False @@ -82,6 +165,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict: "document_loader": file.document_loader_name, "text_splitter": file.text_splitter_name, "create_time": file.create_time, + "file_mtime": file.file_mtime, + "file_size": file.file_size, + "custom_docs": file.custom_docs, + "docs_count": file.docs_count, } else: return {} diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index ae027c12..7ea5d271 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -3,7 +3,7 @@ import urllib from fastapi import File, Form, Body, Query, UploadFile from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile +from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory @@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" return data -async def list_docs( +async def list_files( knowledge_base_name: str ) -> ListResponse: if not validate_kb_name(knowledge_base_name): @@ -40,7 +40,7 @@ async def list_docs( if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = kb.list_docs() + all_doc_names = kb.list_files() return ListResponse(data=all_doc_names) @@ -183,14 +183,14 @@ async def recreate_vector_store( set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. 
''' - async def output(): + def output(): kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: kb.create_kb() kb.clear_vs() - docs = list_docs_from_folder(knowledge_base_name) + docs = list_files_from_folder(knowledge_base_name) for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, knowledge_base_name) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 8d1de489..ca0919e0 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,25 +1,32 @@ +import operator from abc import ABC, abstractmethod import os +import numpy as np from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document +from sklearn.preprocessing import normalize + from server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db, get_kb_detail, ) from server.db.repository.knowledge_file_repository import ( - add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists, - list_docs_from_db, get_file_detail, delete_file_from_db + add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, + count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, + list_docs_from_db, ) from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - EMBEDDING_DEVICE, EMBEDDING_MODEL) + EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, - list_kbs_from_folder, list_docs_from_folder, + list_kbs_from_folder, list_files_from_folder, ) +from server.utils import embedding_device from typing import List, Union, Dict +from typing import List, Union, Dict, Optional class SupportedVSType: @@ -41,7 +48,7 @@ class KBService(ABC): self.doc_path = get_doc_path(self.kb_name) self.do_init() - def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings: + def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) def create_kb(self): @@ -62,7 +69,6 @@ class KBService(ABC): status = delete_files_from_db(self.kb_name) return status - def drop_kb(self): """ 删除知识库 @@ -71,16 +77,24 @@ class KBService(ABC): status = delete_kb_from_db(self.kb_name) return status - def add_doc(self, kb_file: KnowledgeFile, **kwargs): + def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 向知识库添加文件 + 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True """ - docs = kb_file.file2text() + if docs: + custom_docs = True + else: + docs = kb_file.file2text() + custom_docs = False + if docs: self.delete_doc(kb_file) - embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings, **kwargs) - status = add_doc_to_db(kb_file) + doc_infos = self.do_add_doc(docs, **kwargs) + status = add_file_to_db(kb_file, + custom_docs=custom_docs, + docs_count=len(docs), + doc_infos=doc_infos) else: status = False return status @@ -95,20 +109,24 @@ class KBService(ABC): os.remove(kb_file.filepath) return status - def update_doc(self, kb_file: KnowledgeFile, **kwargs): + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 + 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True """ if os.path.exists(kb_file.filepath): self.delete_doc(kb_file, **kwargs) - 
return self.add_doc(kb_file, **kwargs) - + return self.add_doc(kb_file, docs=docs, **kwargs) + def exist_doc(self, file_name: str): - return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, + return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)) - def list_docs(self): - return list_docs_from_db(self.kb_name) + def list_files(self): + return list_files_from_db(self.kb_name) + + def count_files(self): + return count_files_from_db(self.kb_name) def search_docs(self, query: str, @@ -119,6 +137,18 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold, embeddings) return docs + # TODO: milvus/pg需要实现该方法 + def get_doc_by_id(self, id: str) -> Optional[Document]: + return None + + def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]: + ''' + 通过file_name或metadata检索Document + ''' + doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) + docs = [self.get_doc_by_id(x["id"]) for x in doc_infos] + return docs + @abstractmethod def do_create_kb(self): """ @@ -168,8 +198,7 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, - ): + ) -> List[Dict]: """ 向知识库添加文档子类实自己逻辑 """ @@ -208,8 +237,9 @@ class KBServiceFactory: return PGKBService(kb_name, embed_model=embed_model) elif SupportedVSType.MILVUS == vector_store_type: from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config - elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. + return MilvusKBService(kb_name, + embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config + elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. 
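The per-file document tracking introduced above (`FileDocModel` plus `KBService.list_docs()` / `get_doc_by_id()`) lets callers recover the vector-store `Document`s that originated from a single file. A small sketch against the FAISS backend — the KB name `test` and `README.md` are the same illustrative values used in this repo's `__main__` examples:

```python
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService

kb = FaissKBService("test")

# all Documents whose file_doc rows point at README.md in this knowledge base
docs = kb.list_docs(file_name="README.md")
print(len(docs), docs[:1])
```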
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService return DefaultKBService(kb_name) @@ -217,7 +247,7 @@ class KBServiceFactory: def get_service_by_name(kb_name: str ) -> KBService: _, vs_type, embed_model = load_kb_from_db(kb_name) - if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db + if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db vs_type = "faiss" return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @@ -256,29 +286,30 @@ def get_kb_details() -> List[Dict]: for i, v in enumerate(result.values()): v['No'] = i + 1 data.append(v) - + return data -def get_kb_doc_details(kb_name: str) -> List[Dict]: +def get_kb_file_details(kb_name: str) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(kb_name) - docs_in_folder = list_docs_from_folder(kb_name) - docs_in_db = kb.list_docs() + files_in_folder = list_files_from_folder(kb_name) + files_in_db = kb.list_files() result = {} - for doc in docs_in_folder: + for doc in files_in_folder: result[doc] = { "kb_name": kb_name, "file_name": doc, "file_ext": os.path.splitext(doc)[-1], "file_version": 0, "document_loader": "", + "docs_count": 0, "text_splitter": "", "create_time": None, "in_folder": True, "in_db": False, } - for doc in docs_in_db: + for doc in files_in_db: doc_detail = get_file_detail(kb_name, doc) if doc_detail: doc_detail["in_db"] = True @@ -292,5 +323,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]: for i, v in enumerate(result.values()): v['No'] = i + 1 data.append(v) - + return data + + +class EmbeddingsFunAdapter(Embeddings): + + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return normalize(self.embeddings.embed_documents(texts)) + + def embed_query(self, text: str) -> List[float]: + query_embed = self.embeddings.embed_query(text) + query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 + normalized_query_embed = normalize(query_embed_2d) + return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + return await normalize(self.embeddings.aembed_documents(texts)) + + async def aembed_query(self, text: str) -> List[float]: + return await normalize(self.embeddings.aembed_query(text)) + + +def score_threshold_process(score_threshold, k, docs): + if score_threshold is not None: + cmp = ( + operator.le + ) + docs = [ + (doc, similarity) + for doc, similarity in docs + if cmp(similarity, score_threshold) + ] + return docs[:k] diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py index 922e39be..a68d59c5 100644 --- a/server/knowledge_base/kb_service/default_kb_service.py +++ b/server/knowledge_base/kb_service/default_kb_service.py @@ -13,7 +13,7 @@ class DefaultKBService(KBService): def do_drop_kb(self): pass - def do_add_doc(self, docs: List[Document], embeddings: Embeddings): + def do_add_doc(self, docs: List[Document]): pass def do_clear_vs(self): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index b3f5439b..15cc790b 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,7 +5,6 @@ from configs.model_config import ( KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, - EMBEDDING_DEVICE, SCORE_THRESHOLD 
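`score_threshold_process` above keeps a pair only when `operator.le(similarity, score_threshold)` holds, i.e. it assumes distance-style scores where smaller is better (consistent with the normalized embeddings and Euclidean distance used with the adapters). A tiny self-contained illustration with made-up scores:

```python
import operator
from langchain.docstore.document import Document

pairs = [(Document(page_content="a"), 0.12),
         (Document(page_content="b"), 0.55),
         (Document(page_content="c"), 0.90)]

threshold, k = 0.6, 2
kept = [(d, s) for d, s in pairs if operator.le(s, threshold)][:k]
print([d.page_content for d, _ in kept])  # ['a', 'b']
```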
) from server.knowledge_base.kb_service.base import KBService, SupportedVSType @@ -13,40 +12,22 @@ from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings -from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings -from langchain.embeddings.openai import OpenAIEmbeddings -from typing import List +from typing import List, Dict, Optional from langchain.docstore.document import Document -from server.utils import torch_gc - - -# make HuggingFaceEmbeddings hashable -def _embeddings_hash(self): - if isinstance(self, HuggingFaceEmbeddings): - return hash(self.model_name) - elif isinstance(self, HuggingFaceBgeEmbeddings): - return hash(self.model_name) - elif isinstance(self, OpenAIEmbeddings): - return hash(self.model) - -HuggingFaceEmbeddings.__hash__ = _embeddings_hash -OpenAIEmbeddings.__hash__ = _embeddings_hash -HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash - -_VECTOR_STORE_TICKS = {} +from server.utils import torch_gc, embedding_device _VECTOR_STORE_TICKS = {} @lru_cache(CACHED_VS_NUM) -def load_vector_store( +def load_faiss_vector_store( knowledge_base_name: str, embed_model: str = EMBEDDING_MODEL, - embed_device: str = EMBEDDING_DEVICE, + embed_device: str = embedding_device(), embeddings: Embeddings = None, tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. -): +) -> FAISS: print(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name) if embeddings is None: @@ -86,22 +67,39 @@ class FaissKBService(KBService): def vs_type(self) -> str: return SupportedVSType.FAISS - @staticmethod - def get_vs_path(knowledge_base_name: str): - return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store") + def get_vs_path(self): + return os.path.join(self.get_kb_path(), "vector_store") - @staticmethod - def get_kb_path(knowledge_base_name: str): - return os.path.join(KB_ROOT_PATH, knowledge_base_name) + def get_kb_path(self): + return os.path.join(KB_ROOT_PATH, self.kb_name) + + def load_vector_store(self) -> FAISS: + return load_faiss_vector_store( + knowledge_base_name=self.kb_name, + embed_model=self.embed_model, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), + ) + + def save_vector_store(self, vector_store: FAISS = None): + vector_store = vector_store or self.load_vector_store() + vector_store.save_local(self.vs_path) + return vector_store + + def refresh_vs_cache(self): + refresh_vs_cache(self.kb_name) + + def get_doc_by_id(self, id: str) -> Optional[Document]: + vector_store = self.load_vector_store() + return vector_store.docstore._dict.get(id) def do_init(self): - self.kb_path = FaissKBService.get_kb_path(self.kb_name) - self.vs_path = FaissKBService.get_vs_path(self.kb_name) + self.kb_path = self.get_kb_path() + self.vs_path = self.get_vs_path() def do_create_kb(self): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - load_vector_store(self.kb_name) + self.load_vector_store() def do_drop_kb(self): self.clear_vs() @@ -113,33 +111,27 @@ class FaissKBService(KBService): score_threshold: float = SCORE_THRESHOLD, embeddings: Embeddings = None, ) -> List[Document]: - search_index = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name)) + search_index = self.load_vector_store() docs = search_index.similarity_search_with_score(query, k=top_k, 
score_threshold=score_threshold) return docs def do_add_doc(self, docs: List[Document], - embeddings: Embeddings, **kwargs, - ): - vector_store = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) - vector_store.add_documents(docs) + ) -> List[Dict]: + vector_store = self.load_vector_store() + ids = vector_store.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] torch_gc() if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - embeddings = self._load_embeddings() - vector_store = load_vector_store(self.kb_name, - embeddings=embeddings, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store = self.load_vector_store() ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] if len(ids) == 0: @@ -148,14 +140,14 @@ class FaissKBService(KBService): vector_store.delete(ids) if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() - return True + return vector_store def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) - refresh_vs_cache(self.kb_name) + self.refresh_vs_cache() def exist_doc(self, file_name: str): if super().exist_doc(file_name): @@ -166,3 +158,11 @@ class FaissKBService(KBService): return "in_folder" else: return False + + +if __name__ == '__main__': + faissService = FaissKBService("test") + faissService.add_doc(KnowledgeFile("README.md", "test")) + faissService.delete_doc(KnowledgeFile("README.md", "test")) + faissService.do_drop_kb() + print(faissService.search_docs("如何启动api服务")) \ No newline at end of file diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index f2c798cf..444765f6 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,12 +1,16 @@ -from typing import List +from typing import List, Dict, Optional +import numpy as np +from faiss import normalize_L2 from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import Milvus +from sklearn.preprocessing import normalize from configs.model_config import SCORE_THRESHOLD, kbs_config -from server.knowledge_base.kb_service.base import KBService, SupportedVSType +from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \ + score_threshold_process from server.knowledge_base.utils import KnowledgeFile @@ -18,6 +22,14 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) + def get_doc_by_id(self, id: str) -> Optional[Document]: + if self.milvus.col: + data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"]) + if len(data_list) > 0: + data = data_list[0] + text = data.pop("text") + return Document(page_content=text, metadata=data) + @staticmethod def search(milvus_name, content, limit=3): search_params = { @@ -36,38 +48,31 @@ class MilvusKBService(KBService): def _load_milvus(self, embeddings: Embeddings = None): if embeddings is None: embeddings = self._load_embeddings() - self.milvus = Milvus(embedding_function=embeddings, + self.milvus = 
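Building on the custom-docs support added to `KBService.add_doc()`, pre-built `Document`s can be written to a knowledge base without re-parsing the source file (the DB row is then flagged `custom_docs=True`). A sketch using the FAISS service and the same illustrative `test` KB as the module's `__main__`:

```python
from langchain.docstore.document import Document
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.utils import KnowledgeFile

kb = FaissKBService("test")
kb_file = KnowledgeFile("README.md", "test")

# pass pre-split Documents directly; file2text() is skipped for this file
docs = [Document(page_content="如何启动api服务",
                 metadata={"source": kb_file.filepath})]
kb.add_doc(kb_file, docs=docs)
```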
Milvus(embedding_function=EmbeddingsFunAdapter(embeddings), collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) def do_init(self): self._load_milvus() def do_drop_kb(self): - self.milvus.col.drop() + if self.milvus.col: + self.milvus.col.drop() - def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings): - # todo: support score threshold - self._load_milvus(embeddings=embeddings) - return self.milvus.similarity_search_with_score(query, top_k) + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): + self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) + return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) - def add_doc(self, kb_file: KnowledgeFile, **kwargs): - """ - 向知识库添加文件 - """ - docs = kb_file.file2text() - self.milvus.add_documents(docs) - from server.db.repository.knowledge_file_repository import add_doc_to_db - status = add_doc_to_db(kb_file) - return status - - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): - pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + ids = self.milvus.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - filepath = kb_file.filepath.replace('\\', '\\\\') - delete_list = [item.get("pk") for item in - self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] - self.milvus.col.delete(expr=f'pk in {delete_list}') + if self.milvus.col: + filepath = kb_file.filepath.replace('\\', '\\\\') + delete_list = [item.get("pk") for item in + self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])] + self.milvus.col.delete(expr=f'pk in {delete_list}') def do_clear_vs(self): if self.milvus.col: @@ -80,7 +85,9 @@ if __name__ == '__main__': Base.metadata.create_all(bind=engine) milvusService = MilvusKBService("test") - milvusService.add_doc(KnowledgeFile("README.md", "test")) - milvusService.delete_doc(KnowledgeFile("README.md", "test")) - milvusService.do_drop_kb() - print(milvusService.search_docs("测试")) + # milvusService.add_doc(KnowledgeFile("README.md", "test")) + + print(milvusService.get_doc_by_id("444022434274215486")) + # milvusService.delete_doc(KnowledgeFile("README.md", "test")) + # milvusService.do_drop_kb() + # print(milvusService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 6876bd86..e6381fac 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -1,25 +1,38 @@ -from typing import List +import json +from typing import List, Dict, Optional from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores import PGVector +from langchain.vectorstores.pgvector import DistanceStrategy from sqlalchemy import text from configs.model_config import EMBEDDING_DEVICE, kbs_config -from server.knowledge_base.kb_service.base import SupportedVSType, KBService + +from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ + score_threshold_process from server.knowledge_base.utils import load_embeddings, KnowledgeFile +from server.utils import embedding_device as get_embedding_device class PGKBService(KBService): pg_vector: PGVector - def 
_load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None): + def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None): _embeddings = embeddings if _embeddings is None: _embeddings = load_embeddings(self.embed_model, embedding_device) - self.pg_vector = PGVector(embedding_function=_embeddings, + self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings), collection_name=self.kb_name, + distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) + def get_doc_by_id(self, id: str) -> Optional[Document]: + with self.pg_vector.connect() as connect: + stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id") + results = [Document(page_content=row[0], metadata=row[1]) for row in + connect.execute(stmt, parameters={'id': id}).fetchall()] + if len(results) > 0: + return results[0] def do_init(self): self._load_pg_vector() @@ -44,22 +57,14 @@ class PGKBService(KBService): connect.commit() def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): - # todo: support score threshold self._load_pg_vector(embeddings=embeddings) - return self.pg_vector.similarity_search_with_score(query, top_k) + return score_threshold_process(score_threshold, top_k, + self.pg_vector.similarity_search_with_score(query, top_k)) - def add_doc(self, kb_file: KnowledgeFile, **kwargs): - """ - 向知识库添加文件 - """ - docs = kb_file.file2text() - self.pg_vector.add_documents(docs) - from server.db.repository.knowledge_file_repository import add_doc_to_db - status = add_doc_to_db(kb_file) - return status - - def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs): - pass + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + ids = self.pg_vector.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with self.pg_vector.connect() as connect: @@ -77,10 +82,11 @@ class PGKBService(KBService): if __name__ == '__main__': from server.db.base import Base, engine - Base.metadata.create_all(bind=engine) + # Base.metadata.create_all(bind=engine) pGKBService = PGKBService("test") - pGKBService.create_kb() - pGKBService.add_doc(KnowledgeFile("README.md", "test")) - pGKBService.delete_doc(KnowledgeFile("README.md", "test")) - pGKBService.drop_kb() - print(pGKBService.search_docs("测试")) + # pGKBService.create_kb() + # pGKBService.add_doc(KnowledgeFile("README.md", "test")) + # pGKBService.delete_doc(KnowledgeFile("README.md", "test")) + # pGKBService.drop_kb() + print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772")) + # print(pGKBService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index c96d3867..129dd53c 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -1,10 +1,17 @@ from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE -from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile -from server.knowledge_base.kb_service.base import KBServiceFactory -from server.db.repository.knowledge_file_repository import add_doc_to_db +from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, + list_files_from_folder, run_in_thread_pool, + files2docs_in_thread, + KnowledgeFile,) 
+from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
+from server.db.repository.knowledge_file_repository import add_file_to_db
 from server.db.base import Base, engine
 import os
-from typing import Literal, Callable, Any
+from concurrent.futures import ThreadPoolExecutor
+from typing import Literal, Any, List
+
+
+pool = ThreadPoolExecutor(os.cpu_count())
 
 
 def create_tables():
@@ -16,13 +23,22 @@ def reset_tables():
     create_tables()
 
 
+def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
+    kb_files = []
+    for file in files:
+        try:
+            kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
+            kb_files.append(kb_file)
+        except Exception as e:
+            print(f"{e},已跳过")
+    return kb_files
+
+
 def folder2db(
     kb_name: str,
     mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
-    callback_before: Callable = None,
-    callback_after: Callable = None,
 ):
     '''
     use existed files in local folder to populate database and/or vector store.
@@ -36,70 +52,62 @@ def folder2db(
     kb.create_kb()
 
     if mode == "recreate_vs":
+        files_count = kb.count_files()
+        print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
         kb.clear_vs()
-        docs = list_docs_from_folder(kb_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files_count = kb.count_files()
+        print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
+
+        kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+        for success, result in files2docs_in_thread(kb_files, pool=pool):
+            if success:
+                _, filename, docs = result
+                print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
+                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+            else:
+                print(result)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
+            kb.refresh_vs_cache()
     elif mode == "fill_info_only":
-        docs = list_docs_from_folder(kb_name)
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                add_doc_to_db(kb_file)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files = list_files_from_folder(kb_name)
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for kb_file in kb_files:
+            add_file_to_db(kb_file)
+            print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
     elif mode == "update_in_db":
-        docs = kb.list_docs()
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        files = kb.list_files()
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for kb_file in kb_files:
+            kb.update_doc(kb_file, not_refresh_vs_cache=True)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
+            kb.refresh_vs_cache()
     elif mode == "increament":
-        db_docs = kb.list_docs()
-        folder_docs = list_docs_from_folder(kb_name)
-        docs = list(set(folder_docs) - set(db_docs))
-        for i, doc in enumerate(docs):
-            try:
-                kb_file = KnowledgeFile(doc, kb_name)
-                if callable(callback_before):
-                    callback_before(kb_file, i, docs)
-                if i == len(docs) - 1:
-                    not_refresh_vs_cache = False
-                else:
-                    not_refresh_vs_cache = True
-                kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
-                if callable(callback_after):
-                    callback_after(kb_file, i, docs)
-            except Exception as e:
-                print(e)
+        db_files = kb.list_files()
+        folder_files = list_files_from_folder(kb_name)
+        files = list(set(folder_files) - set(db_files))
+        kb_files = file_to_kbfile(kb_name, files)
+
+        for success, result in files2docs_in_thread(kb_files, pool=pool):
+            if success:
+                _, filename, docs = result
+                print(f"正在将 {kb_name}/{filename} 添加到向量库")
+                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+                kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+            else:
+                print(result)
+
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
+            kb.refresh_vs_cache()
     else:
-        raise ValueError(f"unspported migrate mode: {mode}")
+        print(f"unsupported migrate mode: {mode}")
 
 
 def recreate_all_vs(
@@ -114,30 +122,34 @@
         folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
 
 
-def prune_db_docs(kb_name: str):
+def prune_db_files(kb_name: str):
     '''
-    delete docs in database that not existed in local folder.
-    it is used to delete database docs after user deleted some doc files in file browser
+    delete files in the database that no longer exist in the local folder.
+    it is used to clean up database records after the user deleted some doc files in the file browser.
     '''
     kb = KBServiceFactory.get_service_by_name(kb_name)
     if kb.exists():
-        docs_in_db = kb.list_docs()
-        docs_in_folder = list_docs_from_folder(kb_name)
-        docs = list(set(docs_in_db) - set(docs_in_folder))
-        for doc in docs:
-            kb.delete_doc(KnowledgeFile(doc, kb_name))
-        return docs
+        files_in_db = kb.list_files()
+        files_in_folder = list_files_from_folder(kb_name)
+        files = list(set(files_in_db) - set(files_in_folder))
+        kb_files = file_to_kbfile(kb_name, files)
+        for kb_file in kb_files:
+            kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+        if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
+            kb.refresh_vs_cache()
+        return kb_files
 
 
-def prune_folder_docs(kb_name: str):
+def prune_folder_files(kb_name: str):
     '''
     delete doc files in local folder that not existed in database.
     is is used to free local disk space by delete unused doc files.
''' kb = KBServiceFactory.get_service_by_name(kb_name) if kb.exists(): - docs_in_db = kb.list_docs() - docs_in_folder = list_docs_from_folder(kb_name) - docs = list(set(docs_in_folder) - set(docs_in_db)) - for doc in docs: - os.remove(get_file_path(kb_name, doc)) - return docs + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_folder) - set(files_in_db)) + for file in files: + os.remove(get_file_path(kb_name, file)) + return files diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 34f20832..a8a9bcca 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -12,6 +12,26 @@ from configs.model_config import ( from functools import lru_cache import importlib from text_splitter import zh_title_enhance +import langchain.document_loaders +from langchain.docstore.document import Document +from pathlib import Path +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Union, Callable, Dict, Optional, Tuple, Generator + + +# make HuggingFaceEmbeddings hashable +def _embeddings_hash(self): + if isinstance(self, HuggingFaceEmbeddings): + return hash(self.model_name) + elif isinstance(self, HuggingFaceBgeEmbeddings): + return hash(self.model_name) + elif isinstance(self, OpenAIEmbeddings): + return hash(self.model) + +HuggingFaceEmbeddings.__hash__ = _embeddings_hash +OpenAIEmbeddings.__hash__ = _embeddings_hash +HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash def validate_kb_name(knowledge_base_id: str) -> bool: @@ -20,27 +40,34 @@ def validate_kb_name(knowledge_base_id: str) -> bool: return False return True + def get_kb_path(knowledge_base_name: str): return os.path.join(KB_ROOT_PATH, knowledge_base_name) + def get_doc_path(knowledge_base_name: str): return os.path.join(get_kb_path(knowledge_base_name), "content") + def get_vs_path(knowledge_base_name: str): return os.path.join(get_kb_path(knowledge_base_name), "vector_store") + def get_file_path(knowledge_base_name: str, doc_name: str): return os.path.join(get_doc_path(knowledge_base_name), doc_name) + def list_kbs_from_folder(): return [f for f in os.listdir(KB_ROOT_PATH) if os.path.isdir(os.path.join(KB_ROOT_PATH, f))] -def list_docs_from_folder(kb_name: str): + +def list_files_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) return [file for file in os.listdir(doc_path) if os.path.isfile(os.path.join(doc_path, file))] + @lru_cache(1) def load_embeddings(model: str, device: str): if model == "text-embedding-ada-002": # openai text-embedding-ada-002 @@ -56,16 +83,90 @@ def load_embeddings(model: str, device: str): return embeddings - -LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst', - '.rtf', '.txt', '.xml', - '.doc', '.docx', '.epub', '.odt', '.pdf', - '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv' +LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], + "UnstructuredMarkdownLoader": ['.md'], + "CustomJSONLoader": [".json"], "CSVLoader": [".csv"], - "PyPDFLoader": [".pdf"], + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], + "UnstructuredFileLoader": ['.eml', '.msg', '.rst', + '.rtf', '.txt', '.xml', + '.doc', '.docx', '.epub', '.odt', + '.ppt', '.pptx', '.tsv'], # '.xlsx' } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] + +class CustomJSONLoader(langchain.document_loaders.JSONLoader): + ''' + langchain的JSONLoader需要jq,在win上使用不便,进行替代。 + ''' + + def 
__init__( + self, + file_path: Union[str, Path], + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + json_lines: bool = False, + ): + """Initialize the JSONLoader. + + Args: + file_path (Union[str, Path]): The path to the JSON or JSON Lines file. + content_key (str): The key to use to extract the content from the JSON if + results to a list of objects (dict). + metadata_func (Callable[Dict, Dict]): A function that takes in the JSON + object extracted by the jq_schema and the default metadata and returns + a dict of the updated metadata. + text_content (bool): Boolean flag to indicate whether the content is in + string format, default to True. + json_lines (bool): Boolean flag to indicate whether the input is in + JSON Lines format. + """ + self.file_path = Path(file_path).resolve() + self._content_key = content_key + self._metadata_func = metadata_func + self._text_content = text_content + self._json_lines = json_lines + + # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows. + # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded. + def load(self) -> List[Document]: + """Load and return documents from the JSON file.""" + docs: List[Document] = [] + if self._json_lines: + with self.file_path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + self._parse(line, docs) + else: + self._parse(self.file_path.read_text(encoding="utf-8"), docs) + return docs + + def _parse(self, content: str, docs: List[Document]) -> None: + """Convert given content to documents.""" + data = json.loads(content) + + # Perform some validation + # This is not a perfect validation, but it should catch most cases + # and prevent the user from getting a cryptic error later on. 
+ if self._content_key is not None: + self._validate_content_key(data) + + for i, sample in enumerate(data, len(docs) + 1): + metadata = dict( + source=str(self.file_path), + seq_num=i, + ) + text = self._get_text(sample=sample, metadata=metadata) + docs.append(Document(page_content=text, metadata=metadata)) + + +langchain.document_loaders.CustomJSONLoader = CustomJSONLoader + + def get_LoaderClass(file_extension): for LoaderClass, extensions in LOADER_DICT.items(): if file_extension in extensions: @@ -90,10 +191,16 @@ class KnowledgeFile: # TODO: 增加依据文件格式匹配text_splitter self.text_splitter_name = None - def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE): - print(self.document_loader_name) + def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False): + if self.docs is not None and not refresh: + return self.docs + + print(f"{self.document_loader_name} used for {self.filepath}") try: - document_loaders_module = importlib.import_module('langchain.document_loaders') + if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') DocumentLoader = getattr(document_loaders_module, self.document_loader_name) except Exception as e: print(e) @@ -101,6 +208,16 @@ class KnowledgeFile: DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") if self.document_loader_name == "UnstructuredFileLoader": loader = DocumentLoader(self.filepath, autodetect_encoding=True) + elif self.document_loader_name == "CSVLoader": + loader = DocumentLoader(self.filepath, encoding="utf-8") + elif self.document_loader_name == "JSONLoader": + loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) + elif self.document_loader_name == "CustomJSONLoader": + loader = DocumentLoader(self.filepath, text_content=False) + elif self.document_loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(self.filepath, mode="elements") + elif self.document_loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(self.filepath, mode="elements") else: loader = DocumentLoader(self.filepath) @@ -136,4 +253,63 @@ class KnowledgeFile: print(docs[0]) if using_zh_title_enhance: docs = zh_title_enhance(docs) + self.docs = docs return docs + + def get_mtime(self): + return os.path.getmtime(self.filepath) + + def get_size(self): + return os.path.getsize(self.filepath) + + +def run_in_thread_pool( + func: Callable, + params: List[Dict] = [], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 + 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 + ''' + tasks = [] + if pool is None: + pool = ThreadPoolExecutor() + + for kwargs in params: + thread = pool.submit(func, **kwargs) + tasks.append(thread) + + for obj in as_completed(tasks): + yield obj.result() + + +def files2docs_in_thread( + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 利用多线程批量将文件转化成langchain Document. 
+    生成器依次返回 (status, result):status 为 True 时,result 为 (kb_name, file_name, docs);status 为 False 时,result 为错误信息。
+    '''
+    def task(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Union[Tuple[str, str, List[Document]], Exception]]:
+        try:
+            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
+        except Exception as e:
+            return False, e
+
+    kwargs_list = []
+    for i, file in enumerate(files):
+        kwargs = {}
+        if isinstance(file, tuple) and len(file) >= 2:
+            files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+        elif isinstance(file, dict):
+            filename = file.pop("filename")
+            kb_name = file.pop("kb_name")
+            files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+            kwargs = file
+        kwargs["file"] = files[i]
+        kwargs_list.append(kwargs)
+
+    for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+        yield result
diff --git a/server/llm_api.py b/server/llm_api.py
index ab71b3db..d9667e4f 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,8 +4,8 @@
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 
 host_ip = "0.0.0.0"
 
@@ -15,13 +15,6 @@
 openai_api_port = 8888
 base_url = "http://127.0.0.1:{}"
 
-def set_httpx_timeout(timeout=60.0):
-    import httpx
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
 def create_controller_app(
     dispatch_method="shortest_queue",
 ):
@@ -41,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
new file mode 100644
index 00000000..932c2f3b
--- /dev/null
+++ b/server/model_workers/__init__.py
@@ -0,0 +1 @@
+from .zhipu import ChatGLMWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
new file mode 100644
index 00000000..b72f6839
--- /dev/null
+++ b/server/model_workers/base.py
@@ -0,0 +1,71 @@
+from configs.model_config import LOG_PATH
+import fastchat.constants
+fastchat.constants.LOGDIR = LOG_PATH
+from fastchat.serve.model_worker import BaseModelWorker
+import uuid
+import json
+import sys
+from pydantic import BaseModel
+import fastchat
+import threading
+from typing import Dict, List
+
+
+# 恢复被fastchat覆盖的标准输出
+sys.stdout = sys.__stdout__
+sys.stderr = sys.__stderr__
+
+
+class ApiModelOutMsg(BaseModel):
+    error_code: int = 0
+    text: str
+
+class ApiModelWorker(BaseModelWorker):
+    BASE_URL: str
+    SUPPORT_MODELS: List
+
+    def __init__(
+        self,
+        model_names: List[str],
+        controller_addr: str,
+        worker_addr: str,
+        context_len: int = 2048,
+        **kwargs,
+    ):
+        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
+        kwargs.setdefault("model_path", "")
+        kwargs.setdefault("limit_worker_concurrency", 5)
+        super().__init__(model_names=model_names,
+                         controller_addr=controller_addr,
+                         worker_addr=worker_addr,
+                         **kwargs)
+        self.context_len = context_len
+        self.init_heart_beat()
+
+    def count_token(self, params):
+        # TODO:需要完善
+        print("count token")
+        print(params)
+        prompt = params["prompt"]
+        return {"count": len(str(prompt)),
"error_code": 0} + + def generate_stream_gate(self, params): + self.call_ct += 1 + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def get_embeddings(self, params): + print("embedding") + print(params) + + # workaround to make program exit with Ctrl+c + # it should be deleted after pr is merged by fastchat + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ) + self.heart_beat_thread.start() diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py new file mode 100644 index 00000000..4e4e15e0 --- /dev/null +++ b/server/model_workers/zhipu.py @@ -0,0 +1,75 @@ +import zhipuai +from server.model_workers.base import ApiModelWorker +from fastchat import conversation as conv +import sys +import json +from typing import List, Literal + + +class ChatGLMWorker(ApiModelWorker): + BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api" + SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"] + + def __init__( + self, + *, + model_names: List[str] = ["chatglm-api"], + version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std", + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 32768) + super().__init__(**kwargs) + self.version = version + + # 这里的是chatglm api的模板,其它API的conv_template需要定制 + self.conv = conv.Conversation( + name="chatglm-api", + system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", + messages=[], + roles=["Human", "Assistant"], + sep="\n### ", + stop_str="###", + ) + + def generate_stream_gate(self, params): + # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题 + from server.utils import get_model_worker_config + + super().generate_stream_gate(params) + zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key") + + response = zhipuai.model_api.sse_invoke( + model=self.version, + prompt=[{"role": "user", "content": params["prompt"]}], + temperature=params.get("temperature"), + top_p=params.get("top_p"), + incremental=False, + ) + for e in response.events(): + if e.event == "add": + yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0" + # TODO: 更健壮的消息处理 + # elif e.event == "finish": + # ... 
+ + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = ChatGLMWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20003", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20003) diff --git a/server/utils.py b/server/utils.py index 4a887225..0e53e3df 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,16 +1,20 @@ import pydantic from pydantic import BaseModel from typing import List -import torch from fastapi import FastAPI from pathlib import Path import asyncio -from typing import Any, Optional +from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE +from configs.server_config import FSCHAT_MODEL_WORKERS +import os +from server import model_workers +from typing import Literal, Optional, Any class BaseResponse(BaseModel): code: int = pydantic.Field(200, description="API status code") msg: str = pydantic.Field("success", description="API status message") + data: Any = pydantic.Field(None, description="API data") class Config: schema_extra = { @@ -68,6 +72,7 @@ class ChatMessage(BaseModel): } def torch_gc(): + import torch if torch.cuda.is_available(): # with torch.cuda.device(DEVICE): torch.cuda.empty_cache() @@ -186,3 +191,117 @@ def MakeFastAPIOffline( with_google_fonts=False, redoc_favicon_url=favicon, ) + + +# 从server_config中获取服务信息 +def get_model_worker_config(model_name: str = LLM_MODEL) -> dict: + ''' + 加载model worker的配置项。 + 优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"] + ''' + from configs.server_config import FSCHAT_MODEL_WORKERS + from configs.model_config import llm_model_dict + + config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() + config.update(llm_model_dict.get(model_name, {})) + config.update(FSCHAT_MODEL_WORKERS.get(model_name, {})) + + # 如果没有设置有效的local_model_path,则认为是在线模型API + if not os.path.isdir(config.get("local_model_path", "")): + config["online_api"] = True + if provider := config.get("provider"): + try: + config["worker_class"] = getattr(model_workers, provider) + except Exception as e: + print(f"在线模型 ‘{model_name}’ 的provider没有正确配置") + + config["device"] = llm_device(config.get("device") or LLM_DEVICE) + return config + + +def get_all_model_worker_configs() -> dict: + result = {} + model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys()) + for name in model_names: + if name != "default": + result[name] = get_model_worker_config(name) + return result + + +def fschat_controller_address() -> str: + from configs.server_config import FSCHAT_CONTROLLER + + host = FSCHAT_CONTROLLER["host"] + port = FSCHAT_CONTROLLER["port"] + return f"http://{host}:{port}" + + +def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str: + if model := get_model_worker_config(model_name): + host = model["host"] + port = model["port"] + return f"http://{host}:{port}" + return "" + + +def fschat_openai_api_address() -> str: + from configs.server_config import FSCHAT_OPENAI_API + + host = FSCHAT_OPENAI_API["host"] + port = FSCHAT_OPENAI_API["port"] + return f"http://{host}:{port}" + + +def api_address() -> str: + from configs.server_config import API_SERVER + + host = API_SERVER["host"] + port = API_SERVER["port"] + return f"http://{host}:{port}" + + +def webui_address() -> str: + 
from configs.server_config import WEBUI_SERVER + + host = WEBUI_SERVER["host"] + port = WEBUI_SERVER["port"] + return f"http://{host}:{port}" + + +def set_httpx_timeout(timeout: float = None): + ''' + 设置httpx默认timeout。 + httpx默认timeout是5秒,在请求LLM回答时不够用。 + ''' + import httpx + from configs.server_config import HTTPX_DEFAULT_TIMEOUT + + timeout = timeout or HTTPX_DEFAULT_TIMEOUT + httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout + + +# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch +def detect_device() -> Literal["cuda", "mps", "cpu"]: + try: + import torch + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + except: + pass + return "cpu" + + +def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]: + if device not in ["cuda", "mps", "cpu"]: + device = detect_device() + return device + + +def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]: + if device not in ["cuda", "mps", "cpu"]: + device = detect_device() + return device diff --git a/startup.py b/startup.py index c8706e28..591e3b10 100644 --- a/startup.py +++ b/startup.py @@ -1,59 +1,60 @@ -from multiprocessing import Process, Queue +import asyncio import multiprocessing as mp +import os import subprocess import sys -import os +from multiprocessing import Process, Queue from pprint import pprint # 设置numexpr最大线程数,默认为CPU核心数 try: import numexpr + n_cores = numexpr.utils.detect_number_of_cores() os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores) except: pass sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \ +from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \ logger -from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, - FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address, - fschat_openai_api_address, ) -from server.utils import MakeFastAPIOffline, FastAPI +from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER, + FSCHAT_OPENAI_API, ) +from server.utils import (fschat_controller_address, fschat_model_worker_address, + fschat_openai_api_address, set_httpx_timeout, + get_model_worker_config, get_all_model_worker_configs, + MakeFastAPIOffline, FastAPI, llm_device, embedding_device) import argparse -from typing import Tuple, List +from typing import Tuple, List, Dict from configs import VERSION -def set_httpx_timeout(timeout=60.0): - import httpx - httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout - httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout - httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout - - def create_controller_app( dispatch_method: str, + log_level: str = "INFO", ) -> FastAPI: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.controller import app, Controller + from fastchat.serve.controller import app, Controller, logger + logger.setLevel(log_level) controller = Controller(dispatch_method) sys.modules["fastchat.serve.controller"].controller = controller MakeFastAPIOffline(app) app.title = "FastChat Controller" + app._controller = controller return app -def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]: +def create_model_worker_app(log_level: str = "INFO", **kwargs) -> 
Tuple[argparse.ArgumentParser, FastAPI]: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id + from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger import argparse import threading import fastchat.serve.model_worker + logger.setLevel(log_level) # workaround to make program exit with Ctrl+c # it should be deleted after pr is merged by fastchat @@ -99,53 +100,76 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI] ) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus - gptq_config = GptqConfig( - ckpt=args.gptq_ckpt or args.model_path, - wbits=args.gptq_wbits, - groupsize=args.gptq_groupsize, - act_order=args.gptq_act_order, - ) - awq_config = AWQConfig( - ckpt=args.awq_ckpt or args.model_path, - wbits=args.awq_wbits, - groupsize=args.awq_groupsize, - ) + # 在线模型API + if worker_class := kwargs.get("worker_class"): + worker = worker_class(model_names=args.model_names, + controller_addr=args.controller_address, + worker_addr=args.worker_address) + # 本地模型 + else: + # workaround to make program exit with Ctrl+c + # it should be deleted after pr is merged by fastchat + def _new_init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ) + self.heart_beat_thread.start() - worker = ModelWorker( - controller_addr=args.controller_address, - worker_addr=args.worker_address, - worker_id=worker_id, - model_path=args.model_path, - model_names=args.model_names, - limit_worker_concurrency=args.limit_worker_concurrency, - no_register=args.no_register, - device=args.device, - num_gpus=args.num_gpus, - max_gpu_memory=args.max_gpu_memory, - load_8bit=args.load_8bit, - cpu_offloading=args.cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, - stream_interval=args.stream_interval, - conv_template=args.conv_template, - ) + ModelWorker.init_heart_beat = _new_init_heart_beat + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + + worker = ModelWorker( + controller_addr=args.controller_address, + worker_addr=args.worker_address, + worker_id=worker_id, + model_path=args.model_path, + model_names=args.model_names, + limit_worker_concurrency=args.limit_worker_concurrency, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + ) + sys.modules["fastchat.serve.model_worker"].args = args + sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config sys.modules["fastchat.serve.model_worker"].worker = worker - sys.modules["fastchat.serve.model_worker"].args = args - sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config MakeFastAPIOffline(app) - app.title = f"FastChat LLM Server ({LLM_MODEL})" + app.title = f"FastChat LLM Server ({args.model_names[0]})" + app._worker = worker return app def create_openai_api_app( controller_address: str, api_keys: List = [], + log_level: str = "INFO", ) -> FastAPI: 
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+    from fastchat.utils import build_logger
+    logger = build_logger("openai_api", "openai_api.log")
+    logger.setLevel(log_level)
 
     app.add_middleware(
         CORSMiddleware,
@@ -155,6 +179,7 @@ def create_openai_api_app(
         allow_headers=["*"],
     )
 
+    sys.modules["fastchat.serve.openai_api_server"].logger = logger
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys
 
@@ -164,6 +189,9 @@
 def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+    if q is None or not isinstance(run_seq, int):
+        return
+
     if run_seq == 1:
         @app.on_event("startup")
         async def on_startup():
@@ -182,15 +210,90 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
         q.put(run_seq)
 
 
-def run_controller(q: Queue, run_seq: int = 1):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
     import uvicorn
+    import httpx
+    from fastapi import Body
+    import time
+    import sys
 
-    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+    app = create_controller_app(
+        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
+        log_level=log_level,
+    )
     _set_app_seq(app, q, run_seq)
 
+    @app.on_event("startup")
+    def on_startup():
+        if e is not None:
+            e.set()
+
+    # add interface to release and load model worker
+    @app.post("/release_worker")
+    def release_worker(
+        model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
+        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+        new_model_name: str = Body(None, description="释放后加载该模型"),
+        keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+    ) -> Dict:
+        available_models = app._controller.list_models()
+        if new_model_name in available_models:
+            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
+            logger.info(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
+        else:
+            logger.info(f"即将停止LLM模型: {model_name}")
+
+        if model_name not in available_models:
+            msg = f"the model {model_name} is not available"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        worker_address = app._controller.get_worker_address(model_name)
+        if not worker_address:
+            msg = f"cannot find model_worker address for {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        r = httpx.post(worker_address + "/release",
+                       json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+        if r.status_code != 200:
+            msg = f"failed to release model: {model_name}"
+            logger.error(msg)
+            return {"code": 500, "msg": msg}
+
+        if new_model_name:
+            timer = 300  # wait 5 minutes for the new model_worker to register
+            while timer > 0:
+                models = app._controller.list_models()
+                if new_model_name in models:
+                    break
+                time.sleep(1)
+                timer -= 1
+            if timer > 0:
+                msg = f"successfully changed model from {model_name} to {new_model_name}"
+                logger.info(msg)
+                return {"code": 200, "msg": msg}
+            else:
+                msg = f"failed to change model from {model_name} to {new_model_name}"
+                logger.error(msg)
+                return {"code": 500, "msg": msg}
+        else:
+            msg = f"successfully released model: {model_name}"
+            logger.info(msg)
+            return {"code": 200, "msg": msg}
+
     host = FSCHAT_CONTROLLER["host"]
     port = FSCHAT_CONTROLLER["port"]
-    uvicorn.run(app, host=host, port=port)
+
+    if log_level == "ERROR":
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+
+    uvicorn.run(app,
host=host, port=port, log_level=log_level.lower()) def run_model_worker( @@ -198,33 +301,59 @@ def run_model_worker( controller_address: str = "", q: Queue = None, run_seq: int = 2, + log_level: str = "INFO", ): import uvicorn + from fastapi import Body + import sys - kwargs = FSCHAT_MODEL_WORKERS[model_name].copy() + kwargs = get_model_worker_config(model_name) host = kwargs.pop("host") port = kwargs.pop("port") - model_path = llm_model_dict[model_name].get("local_model_path", "") - kwargs["model_path"] = model_path kwargs["model_names"] = [model_name] kwargs["controller_address"] = controller_address or fschat_controller_address() - kwargs["worker_address"] = fschat_model_worker_address() + kwargs["worker_address"] = fschat_model_worker_address(model_name) + model_path = kwargs.get("local_model_path", "") + kwargs["model_path"] = model_path - app = create_model_worker_app(**kwargs) + app = create_model_worker_app(log_level=log_level, **kwargs) _set_app_seq(app, q, run_seq) + if log_level == "ERROR": + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ - uvicorn.run(app, host=host, port=port) + # add interface to release and load model + @app.post("/release") + def release_model( + new_model_name: str = Body(None, description="释放后加载该模型"), + keep_origin: bool = Body(False, description="不释放原模型,加载新模型") + ) -> Dict: + if keep_origin: + if new_model_name: + q.put(["start", new_model_name]) + else: + if new_model_name: + q.put(["replace", new_model_name]) + else: + q.put(["stop"]) + return {"code": 200, "msg": "done"} + + uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) -def run_openai_api(q: Queue, run_seq: int = 3): +def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"): import uvicorn + import sys controller_addr = fschat_controller_address() - app = create_openai_api_app(controller_addr) # todo: not support keys yet. + app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. 
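+    # 说明:_set_app_seq 借助进程间 Queue,按 run_seq 指定的顺序协调各服务依次启动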
_set_app_seq(app, q, run_seq) host = FSCHAT_OPENAI_API["host"] port = FSCHAT_OPENAI_API["port"] + if log_level == "ERROR": + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ uvicorn.run(app, host=host, port=port) @@ -244,13 +373,15 @@ def run_api_server(q: Queue, run_seq: int = 4): def run_webui(q: Queue, run_seq: int = 5): host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] - while True: - no = q.get() - if no != run_seq - 1: - q.put(no) - else: - break - q.put(run_seq) + + if q is not None and isinstance(run_seq, int): + while True: + no = q.get() + if no != run_seq - 1: + q.put(no) + else: + break + q.put(run_seq) p = subprocess.Popen(["streamlit", "run", "webui.py", "--server.address", host, "--server.port", str(port)]) @@ -313,6 +444,13 @@ def parse_args() -> argparse.ArgumentParser: help="run api.py server", dest="api", ) + parser.add_argument( + "-p", + "--api-worker", + action="store_true", + help="run online model api such as zhipuai", + dest="api_worker", + ) parser.add_argument( "-w", "--webui", @@ -320,26 +458,38 @@ def parse_args() -> argparse.ArgumentParser: help="run webui.py server", dest="webui", ) + parser.add_argument( + "-q", + "--quiet", + action="store_true", + help="减少fastchat服务log信息", + dest="quiet", + ) args = parser.parse_args() - return args + return args, parser -def dump_server_info(after_start=False): +def dump_server_info(after_start=False, args=None): import platform import langchain import fastchat - from configs.server_config import api_address, webui_address + from server.utils import api_address, webui_address - print("\n\n") + print("\n") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print(f"操作系统:{platform.platform()}.") print(f"python版本:{sys.version}") print(f"项目版本:{VERSION}") print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}") print("\n") - print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}") - pprint(llm_model_dict[LLM_MODEL]) - print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}") + + model = LLM_MODEL + if args and args.model_name: + model = args.model_name + print(f"当前LLM模型:{model} @ {llm_device()}") + pprint(llm_model_dict[model]) + print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") + if after_start: print("\n") print(f"服务端运行信息:") @@ -351,110 +501,231 @@ def dump_server_info(after_start=False): if args.webui: print(f" Chatchat WEBUI Server: {webui_address()}") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) - print("\n\n") + print("\n") -if __name__ == "__main__": +async def start_main_server(): import time + import signal + + def handler(signalname): + """ + Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed. + Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose. 
+ """ + def f(signal_received, frame): + raise KeyboardInterrupt(f"{signalname} received") + return f + + # This will be inherited by the child process if it is forked (not spawned) + signal.signal(signal.SIGINT, handler("SIGINT")) + signal.signal(signal.SIGTERM, handler("SIGTERM")) mp.set_start_method("spawn") - queue = Queue() - args = parse_args() + manager = mp.Manager() + + queue = manager.Queue() + args, parser = parse_args() + if args.all_webui: args.openai_api = True args.model_worker = True args.api = True + args.api_worker = True args.webui = True elif args.all_api: args.openai_api = True args.model_worker = True args.api = True + args.api_worker = True args.webui = False elif args.llm_api: args.openai_api = True args.model_worker = True + args.api_worker = True args.api = False args.webui = False - dump_server_info() - logger.info(f"正在启动服务:") - logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") + dump_server_info(args=args) - processes = {} + if len(sys.argv) > 1: + logger.info(f"正在启动服务:") + logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") + processes = {"online-api": []} + + def process_count(): + return len(processes) + len(processes["online-api"]) - 1 + + if args.quiet: + log_level = "ERROR" + else: + log_level = "INFO" + + controller_started = manager.Event() if args.openai_api: process = Process( target=run_controller, - name=f"controller({os.getpid()})", - args=(queue, len(processes) + 1), + name=f"controller", + args=(queue, process_count() + 1, log_level, controller_started), daemon=True, ) - process.start() + processes["controller"] = process process = Process( target=run_openai_api, - name=f"openai_api({os.getpid()})", - args=(queue, len(processes) + 1), + name=f"openai_api", + args=(queue, process_count() + 1), daemon=True, ) - process.start() processes["openai_api"] = process if args.model_worker: - process = Process( - target=run_model_worker, - name=f"model_worker({os.getpid()})", - args=(args.model_name, args.controller_address, queue, len(processes) + 1), - daemon=True, - ) - process.start() - processes["model_worker"] = process + config = get_model_worker_config(args.model_name) + if not config.get("online_api"): + process = Process( + target=run_model_worker, + name=f"model_worker - {args.model_name}", + args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level), + daemon=True, + ) + + processes["model_worker"] = process + + if args.api_worker: + configs = get_all_model_worker_configs() + for model_name, config in configs.items(): + if config.get("online_api") and config.get("worker_class"): + process = Process( + target=run_model_worker, + name=f"model_worker - {model_name}", + args=(model_name, args.controller_address, queue, process_count() + 1, log_level), + daemon=True, + ) + + processes["online-api"].append(process) if args.api: process = Process( target=run_api_server, - name=f"API Server{os.getpid()})", - args=(queue, len(processes) + 1), + name=f"API Server", + args=(queue, process_count() + 1), daemon=True, ) - process.start() + processes["api"] = process if args.webui: process = Process( target=run_webui, - name=f"WEBUI Server{os.getpid()})", - args=(queue, len(processes) + 1), + name=f"WEBUI Server", + args=(queue, process_count() + 1), daemon=True, ) - process.start() + processes["webui"] = process - try: - # log infors - while True: - no = queue.get() - if no == len(processes): - time.sleep(0.5) - dump_server_info(True) - break - else: - queue.put(no) + if process_count() == 0: + parser.print_help() + else: + try: + # 
保证任务收到SIGINT后,能够正常退出 + if p:= processes.get("controller"): + p.start() + p.name = f"{p.name} ({p.pid})" + controller_started.wait() - if model_worker_process := processes.get("model_worker"): - model_worker_process.join() - for name, process in processes.items(): - if name != "model_worker": + if p:= processes.get("openai_api"): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("model_worker"): + p.start() + p.name = f"{p.name} ({p.pid})" + + for p in processes.get("online-api", []): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("api"): + p.start() + p.name = f"{p.name} ({p.pid})" + + if p:= processes.get("webui"): + p.start() + p.name = f"{p.name} ({p.pid})" + + while True: + no = queue.get() + if no == process_count(): + time.sleep(0.5) + dump_server_info(after_start=True, args=args) + break + else: + queue.put(no) + + if model_worker_process := processes.get("model_worker"): + model_worker_process.join() + for process in processes.get("online-api", []): process.join() - except: - if model_worker_process := processes.get("model_worker"): - model_worker_process.terminate() - for name, process in processes.items(): - if name != "model_worker": - process.terminate() + for name, process in processes.items(): + if name not in ["model_worker", "online-api"]: + if isinstance(p, list): + for work_process in p: + work_process.join() + else: + process.join() + except Exception as e: + # if model_worker_process := processes.pop("model_worker", None): + # model_worker_process.terminate() + # for process in processes.pop("online-api", []): + # process.terminate() + # for process in processes.values(): + # + # if isinstance(process, list): + # for work_process in process: + # work_process.terminate() + # else: + # process.terminate() + logger.error(e) + logger.warning("Caught KeyboardInterrupt! 
Setting stop event...") + finally: + # Send SIGINT if process doesn't exit quickly enough, and kill it as last resort + # .is_alive() also implicitly joins the process (good practice in linux) + # while alive_procs := [p for p in processes.values() if p.is_alive()]: + + for p in processes.values(): + logger.warning("Sending SIGKILL to %s", p) + # Queues and other inter-process communication primitives can break when + # process is killed, but we don't care here + + if isinstance(p, list): + for process in p: + process.kill() + + else: + p.kill() + + for p in processes.values(): + logger.info("Process status: %s", p) + +if __name__ == "__main__": + + if sys.version_info < (3, 10): + loop = asyncio.get_event_loop() + else: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + asyncio.set_event_loop(loop) + # 同步调用协程代码 + loop.run_until_complete(start_main_server()) + # 服务启动后接口调用示例: # import openai diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 5a8b97d2..51bbac19 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -1,4 +1,3 @@ -from doctest import testfile import requests import json import sys @@ -6,7 +5,7 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) -from configs.server_config import api_address +from server.utils import api_address from configs.model_config import VECTOR_SEARCH_TOP_K from server.knowledge_base.utils import get_kb_path @@ -112,7 +111,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"): assert data["msg"] == f"成功上传文件 {name}" -def test_list_docs(api="/knowledge_base/list_docs"): +def test_list_files(api="/knowledge_base/list_files"): url = api_base_url + api print("\n获取知识库中文件列表:") r = requests.get(url, params={"knowledge_base_name": kb}) diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py new file mode 100644 index 00000000..f348fe74 --- /dev/null +++ b/tests/api/test_llm_api.py @@ -0,0 +1,74 @@ +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from configs.server_config import api_address, FSCHAT_MODEL_WORKERS +from configs.model_config import LLM_MODEL, llm_model_dict + +from pprint import pprint +import random + + +def get_configured_models(): + model_workers = list(FSCHAT_MODEL_WORKERS) + if "default" in model_workers: + model_workers.remove("default") + + llm_dict = list(llm_model_dict) + + return model_workers, llm_dict + + +api_base_url = api_address() + + +def get_running_models(api="/llm_model/list_models"): + url = api_base_url + api + r = requests.post(url) + if r.status_code == 200: + return r.json()["data"] + return [] + + +def test_running_models(api="/llm_model/list_models"): + url = api_base_url + api + r = requests.post(url) + assert r.status_code == 200 + print("\n获取当前正在运行的模型列表:") + pprint(r.json()) + assert isinstance(r.json()["data"], list) + assert len(r.json()["data"]) > 0 + + +# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动 +# def test_stop_model(api="/llm_model/stop"): +# url = api_base_url + api +# r = requests.post(url, json={""}) + + +def test_change_model(api="/llm_model/change"): + url = api_base_url + api + + running_models = get_running_models() + assert len(running_models) > 0 + + model_workers, llm_dict = get_configured_models() + + availabel_new_models = set(model_workers) - set(running_models) + if len(availabel_new_models) == 0: + availabel_new_models = set(llm_dict) - 
set(running_models) + availabel_new_models = list(availabel_new_models) + assert len(availabel_new_models) > 0 + print(availabel_new_models) + + model_name = random.choice(running_models) + new_model_name = random.choice(availabel_new_models) + print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}") + r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name}) + assert r.status_code == 200 + + running_models = get_running_models() + assert new_model_name in running_models diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index ad9d3d89..4c2d5faf 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -5,7 +5,7 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent.parent)) from configs.model_config import BING_SUBSCRIPTION_KEY -from configs.server_config import API_SERVER, api_address +from server.utils import api_address from pprint import pprint diff --git a/tests/document_loader/test_imgloader.py b/tests/document_loader/test_imgloader.py new file mode 100644 index 00000000..8bba7da9 --- /dev/null +++ b/tests/document_loader/test_imgloader.py @@ -0,0 +1,21 @@ +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from pprint import pprint + +test_files = { + "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"), +} + +def test_rapidocrpdfloader(): + pdf_path = test_files["ocr_test.pdf"] + from document_loaders import RapidOCRPDFLoader + + loader = RapidOCRPDFLoader(pdf_path) + docs = loader.load() + pprint(docs) + assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) + + diff --git a/tests/document_loader/test_pdfloader.py b/tests/document_loader/test_pdfloader.py new file mode 100644 index 00000000..92460cb4 --- /dev/null +++ b/tests/document_loader/test_pdfloader.py @@ -0,0 +1,21 @@ +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from pprint import pprint + +test_files = { + "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"), +} + +def test_rapidocrloader(): + img_path = test_files["ocr_test.jpg"] + from document_loaders import RapidOCRLoader + + loader = RapidOCRLoader(img_path) + docs = loader.load() + pprint(docs) + assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) + + diff --git a/tests/kb_vector_db/__init__.py b/tests/kb_vector_db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/kb_vector_db/test_faiss_kb.py b/tests/kb_vector_db/test_faiss_kb.py new file mode 100644 index 00000000..9c329c8b --- /dev/null +++ b/tests/kb_vector_db/test_faiss_kb.py @@ -0,0 +1,35 @@ +from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService +from server.knowledge_base.migrate import create_tables +from server.knowledge_base.utils import KnowledgeFile + +kbService = FaissKBService("test") +test_kb_name = "test" +test_file_name = "README.md" +testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) +search_content = "如何启动api服务" + + +def test_init(): + create_tables() + + +def test_create_db(): + assert kbService.create_kb() + + +def test_add_doc(): + assert kbService.add_doc(testKnowledgeFile) + + +def test_search_db(): + result = kbService.search_docs(search_content) + assert len(result) > 0 +def test_delete_doc(): + assert kbService.delete_doc(testKnowledgeFile) + + + + + +def 
test_delete_db(): + assert kbService.drop_kb() diff --git a/tests/kb_vector_db/test_milvus_db.py b/tests/kb_vector_db/test_milvus_db.py new file mode 100644 index 00000000..ed723806 --- /dev/null +++ b/tests/kb_vector_db/test_milvus_db.py @@ -0,0 +1,31 @@ +from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService +from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService +from server.knowledge_base.kb_service.pg_kb_service import PGKBService +from server.knowledge_base.migrate import create_tables +from server.knowledge_base.utils import KnowledgeFile + +kbService = MilvusKBService("test") + +test_kb_name = "test" +test_file_name = "README.md" +testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) +search_content = "如何启动api服务" + +def test_init(): + create_tables() + + +def test_create_db(): + assert kbService.create_kb() + + +def test_add_doc(): + assert kbService.add_doc(testKnowledgeFile) + + +def test_search_db(): + result = kbService.search_docs(search_content) + assert len(result) > 0 +def test_delete_doc(): + assert kbService.delete_doc(testKnowledgeFile) + diff --git a/tests/kb_vector_db/test_pg_db.py b/tests/kb_vector_db/test_pg_db.py new file mode 100644 index 00000000..12448d05 --- /dev/null +++ b/tests/kb_vector_db/test_pg_db.py @@ -0,0 +1,31 @@ +from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService +from server.knowledge_base.kb_service.pg_kb_service import PGKBService +from server.knowledge_base.migrate import create_tables +from server.knowledge_base.utils import KnowledgeFile + +kbService = PGKBService("test") + +test_kb_name = "test" +test_file_name = "README.md" +testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name) +search_content = "如何启动api服务" + + +def test_init(): + create_tables() + + +def test_create_db(): + assert kbService.create_kb() + + +def test_add_doc(): + assert kbService.add_doc(testKnowledgeFile) + + +def test_search_db(): + result = kbService.search_docs(search_content) + assert len(result) > 0 +def test_delete_doc(): + assert kbService.delete_doc(testKnowledgeFile) + diff --git a/tests/samples/ocr_test.jpg b/tests/samples/ocr_test.jpg new file mode 100644 index 00000000..70c199b7 Binary files /dev/null and b/tests/samples/ocr_test.jpg differ diff --git a/tests/samples/ocr_test.pdf b/tests/samples/ocr_test.pdf new file mode 100644 index 00000000..3a137ad1 Binary files /dev/null and b/tests/samples/ocr_test.pdf differ diff --git a/webui.py b/webui.py index 58fc0e39..0cda9ebc 100644 --- a/webui.py +++ b/webui.py @@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu from webui_pages import * import os from configs import VERSION +from server.utils import api_address -api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) + +api = ApiRequest(base_url=api_address()) if __name__ == "__main__": st.set_page_config( diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 04ece7da..25b885b3 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -1,10 +1,14 @@ import streamlit as st +from configs.server_config import FSCHAT_MODEL_WORKERS from webui_pages.utils import * from streamlit_chatbox import * from datetime import datetime from server.chat.search_engine_chat import SEARCH_ENGINES -from typing import List, Dict import os +from configs.model_config import llm_model_dict, LLM_MODEL +from server.utils import get_model_worker_config +from typing import List, Dict + chat_box = ChatBox( 
assistant_avatar=os.path.join( @@ -59,9 +63,39 @@ def dialogue_page(api: ApiRequest): on_change=on_mode_change, key="dialogue_mode", ) - history_len = st.number_input("历史对话轮数:", 0, 10, 3) - # todo: support history len + def on_llm_change(): + st.session_state["prev_llm_model"] = llm_model + + def llm_model_format_func(x): + if x in running_models: + return f"{x} (Running)" + return x + + running_models = api.list_running_models() + config_models = api.list_config_models() + for x in running_models: + if x in config_models: + config_models.remove(x) + llm_models = running_models + config_models + if "prev_llm_model" not in st.session_state: + index = llm_models.index(LLM_MODEL) + else: + index = 0 + llm_model = st.selectbox("选择LLM模型:", + llm_models, + index, + format_func=llm_model_format_func, + on_change=on_llm_change, + # key="llm_model", + ) + if (st.session_state.get("prev_llm_model") != llm_model + and not get_model_worker_config(llm_model).get("online_api")): + with st.spinner(f"正在加载模型: {llm_model}"): + r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model) + st.session_state["prev_llm_model"] = llm_model + + history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN) def on_kb_change(): st.toast(f"已加载知识库: {st.session_state.selected_kb}") @@ -75,7 +109,7 @@ def dialogue_page(api: ApiRequest): on_change=on_kb_change, key="selected_kb", ) - kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3) + kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) @@ -87,13 +121,13 @@ def dialogue_page(api: ApiRequest): options=search_engine_list, index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0, ) - se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3) + se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K) # Display chat messages from history on app rerun chat_box.output_messages() - chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter " + chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter " if prompt := st.chat_input(chat_input_placeholder, key="prompt"): history = get_messages_history(history_len) @@ -101,7 +135,7 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" - r = api.chat_chat(prompt, history) + r = api.chat_chat(prompt, history=history, model=llm_model) for t in r: if error_msg := check_error_msg(t): # check whether error occured st.error(error_msg) @@ -116,7 +150,7 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history): + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) text += d["answer"] @@ -129,8 +163,8 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="网络搜索结果"), ]) text = "" - for d in api.search_engine_chat(prompt, search_engine, se_top_k): - if error_msg := check_error_msg(d): # check whether error occured + for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model): + if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) else: text += d["answer"] diff --git 
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 4351e956..0889ca54 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
 from st_aggrid.grid_options_builder import GridOptionsBuilder
 import pandas as pd
 from server.knowledge_base.utils import get_file_path, LOADER_DICT
-from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
+from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
 from typing import Literal, Dict, Tuple
 from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
 import os
@@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
         # 上传文件
         # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
-        files = st.file_uploader("上传知识文件",
+        files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
                                  [i for ls in LOADER_DICT.values() for i in ls],
                                  accept_multiple_files=True,
                                  )
 
@@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):
 
         # 知识库详情
         # st.info("请选择文件,点击按钮进行操作。")
-        doc_details = pd.DataFrame(get_kb_doc_details(kb))
+        doc_details = pd.DataFrame(get_kb_file_details(kb))
         if not len(doc_details):
             st.info(f"知识库 `{kb}` 中暂无文件")
         else:
@@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
             st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
             doc_details.drop(columns=["kb_name"], inplace=True)
             doc_details = doc_details[[
-                "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
+                "No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
             ]]
             # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
             # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
                 # ("file_ext", "文档类型"): {},
                 # ("file_version", "文档版本"): {},
                 ("document_loader", "文档加载器"): {},
-                ("text_splitter", "分词器"): {},
+                ("docs_count", "文档数量"): {},
+                # ("text_splitter", "分词器"): {},
                 # ("create_time", "创建时间"): {},
                 ("in_folder", "源文件"): {"cellRenderer": cell_renderer},
                 ("in_db", "向量库"): {"cellRenderer": cell_renderer},
@@ -244,7 +245,6 @@ def knowledge_base_page(api: ApiRequest):
         cols = st.columns(3)
-        # todo: freezed
         if cols[0].button(
                 "依据源文件重建向量库",
                 # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
                 use_container_width=True,
         ):
@@ -258,7 +258,7 @@ def knowledge_base_page(api: ApiRequest):
                     if msg := check_error_msg(d):
                         st.toast(msg)
                     else:
-                        empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
+                        empty.progress(d["finished"] / d["total"], d["msg"])
             st.experimental_rerun()
 
         if cols[2].button(
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index c666d458..08511044 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -6,11 +6,14 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    llm_model_dict,
+    HISTORY_LEN,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
     logger,
 )
+from configs.server_config import HTTPX_DEFAULT_TIMEOUT
 import httpx
 import asyncio
 from server.chat.openai_chat import OpenAiChatMsgIn
@@ -20,21 +23,12 @@
 import json
 from io import BytesIO
 from server.db.repository.knowledge_base_repository import get_kb_detail
 from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async
+from server.utils import run_async, iter_over_async, set_httpx_timeout
 from configs.model_config import NLTK_DATA_PATH
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
-
-
-def set_httpx_timeout(timeout=60.0):
-    '''
-    设置httpx默认timeout到60秒。
-    httpx默认timeout是5秒,在请求LLM回答时不够用。
-    '''
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+from pprint import pprint
 
 
 KB_ROOT_PATH = Path(KB_ROOT_PATH)
@@ -50,7 +44,7 @@ class ApiRequest:
     def __init__(
         self,
         base_url: str = "http://127.0.0.1:7861",
-        timeout: float = 60.0,
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
         no_remote_api: bool = False,   # call api view function directly
     ):
         self.base_url = base_url
@@ -224,9 +218,17 @@ class ApiRequest:
         try:
             with response as r:
                 for chunk in r.iter_text(None):
-                    if as_json and chunk:
-                        yield json.loads(chunk)
-                    elif chunk.strip():
+                    if not chunk: # fastchat api yield empty bytes on start and end
+                        continue
+                    if as_json:
+                        try:
+                            data = json.loads(chunk)
+                            pprint(data, depth=1)
+                            yield data
+                        except Exception as e:
+                            logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。")
+                    else:
+                        print(chunk, end="", flush=True)
                         yield chunk
         except httpx.ConnectError as e:
             msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
@@ -274,6 +276,9 @@ class ApiRequest:
             return self._fastapi_stream2generator(response)
         else:
             data = msg.dict(exclude_unset=True, exclude_none=True)
+            print(f"received input message:")
+            pprint(data)
+
             response = self.post(
                 "/chat/fastchat",
                 json=data,
@@ -286,6 +291,7 @@ class ApiRequest:
         query: str,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -298,8 +304,12 @@ class ApiRequest:
             "query": query,
             "history": history,
             "stream": stream,
+            "model_name": model,
         }
 
+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.chat import chat
             response = chat(**data)
@@ -316,6 +326,7 @@ class ApiRequest:
         score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -331,9 +342,13 @@ class ApiRequest:
             "score_threshold": score_threshold,
             "history": history,
             "stream": stream,
+            "model_name": model,
             "local_doc_url": no_remote_api,
         }
 
+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
             response = knowledge_base_chat(**data)
@@ -352,6 +367,7 @@ class ApiRequest:
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
+        model: str = LLM_MODEL,
         no_remote_api: bool = None,
     ):
         '''
@@ -365,8 +381,12 @@ class ApiRequest:
             "search_engine_name": search_engine_name,
             "top_k": top_k,
             "stream": stream,
+            "model_name": model,
         }
 
+        print(f"received input message:")
+        pprint(data)
+
         if no_remote_api:
             from server.chat.search_engine_chat import search_engine_chat
             response = search_engine_chat(**data)
@@ -473,18 +493,18 @@ class ApiRequest:
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/list_docs接口
+        对应api.py/knowledge_base/list_files接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import list_docs
-            response = run_async(list_docs(knowledge_base_name))
+            from server.knowledge_base.kb_doc_api import list_files
+            response = run_async(list_files(knowledge_base_name))
             return response.data
         else:
             response = self.get(
-                "/knowledge_base/list_docs",
+                "/knowledge_base/list_files",
                 params={"knowledge_base_name": knowledge_base_name}
             )
             data = self._check_httpx_json_response(response)
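
With the hunks above, each chat call accepts a per-request `model` and yields its reply incrementally. A minimal consumption sketch following the iteration pattern `dialogue.py` uses, assuming `/chat/chat` streams plain-text chunks and the API server is running at `server.utils.api_address()`:

```python
# Sketch only: stream a /chat/chat reply with the new per-call model_name parameter.
from configs.model_config import LLM_MODEL
from server.utils import api_address
from webui_pages.utils import ApiRequest, check_error_msg

api = ApiRequest(base_url=api_address())

text = ""
for chunk in api.chat_chat("你好", history=[], model=LLM_MODEL):
    if error_msg := check_error_msg(chunk):  # server reported an error for this chunk
        print("error:", error_msg)
        break
    text += chunk                            # assumes chunks arrive as text for this endpoint
print(text)
```
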
@@ -633,6 +653,84 @@ class ApiRequest:
             )
             return self._httpx_stream2generator(response, as_json=True)
 
+    def list_running_models(self, controller_address: str = None):
+        '''
+        获取Fastchat中正运行的模型列表
+        '''
+        r = self.post(
+            "/llm_model/list_models",
+        )
+        return r.json().get("data", [])
+
+    def list_config_models(self):
+        '''
+        获取configs中配置的模型列表
+        '''
+        return list(llm_model_dict.keys())
+
+    def stop_llm_model(
+        self,
+        model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        停止某个LLM模型。
+        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+        '''
+        data = {
+            "model_name": model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/stop",
+            json=data,
+        )
+        return r.json()
+
+    def change_llm_model(
+        self,
+        model_name: str,
+        new_model_name: str,
+        controller_address: str = None,
+    ):
+        '''
+        向fastchat controller请求切换LLM模型。
+        '''
+        if not model_name or not new_model_name:
+            return
+
+        if new_model_name == model_name:
+            return {
+                "code": 200,
+                "msg": "什么都不用做"
+            }
+
+        running_models = self.list_running_models()
+        if model_name not in running_models:
+            return {
+                "code": 500,
+                "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+            }
+
+        config_models = self.list_config_models()
+        if new_model_name not in config_models:
+            return {
+                "code": 500,
+                "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+            }
+
+        data = {
+            "model_name": model_name,
+            "new_model_name": new_model_name,
+            "controller_address": controller_address,
+        }
+        r = self.post(
+            "/llm_model/change",
+            json=data,
+            timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+        )
+        return r.json()
+
 
 def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
     '''