Merge branch 'chatchat-space:dev' into dev

liunux4odoo 2023-09-04 14:44:55 +08:00 committed by GitHub
commit 8475a5dfd3
27 changed files with 799 additions and 123 deletions


@@ -216,12 +216,13 @@ embedding_model_dict = {
当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
```shell
$ python init_database.py
```
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启 `normalize_L2`,需要以下命令初始化或重建知识库:
```shell
$ python init_database.py --recreate-vs
```
@@ -266,7 +267,7 @@ max_gpu_memory="20GiB"
⚠️ **注意:**
**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;**
**2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;**
@@ -324,7 +325,7 @@ $ python server/api.py
启动 API 服务后,可访问 `localhost:7861` 或 `{API 所在服务器 IP}:7861` FastAPI 自动生成的 docs 进行接口查看与测试。
- FastAPI docs 界面
![](img/fastapi_docs_020_0.png)
#### 5.3 启动 Web UI 服务
@@ -348,10 +349,11 @@ $ streamlit run webui.py --server.port 666
```
- Web UI 对话界面:
![](img/webui_0813_0.png)
- Web UI 知识库管理页面:
![](img/webui_0813_1.png)
---
@@ -399,14 +401,14 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## 路线图
- [x] Langchain 应用
  - [x] 本地数据接入
    - [x] 接入非结构化文档
      - [x] .md
      - [x] .txt
      - [x] .docx
    - [ ] 结构化数据接入
      - [x] .csv
      - [ ] .xlsx
    - [ ] 分词及召回
      - [ ] 接入不同类型 TextSplitter
@@ -415,24 +417,24 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
    - [ ] 本地网页接入
    - [ ] SQL 接入
    - [ ] 知识图谱/图数据库接入
  - [x] 搜索引擎接入
    - [x] Bing 搜索
    - [x] DuckDuckGo 搜索
  - [ ] Agent 实现
- [x] LLM 模型接入
  - [x] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
  - [ ] 支持 ChatGLM API 等 LLM API 的接入
- [x] Embedding 模型接入
  - [x] 支持调用 HuggingFace 中各开源 Emebdding 模型
  - [ ] 支持 OpenAI Embedding API 等 Embedding API 的接入
- [x] 基于 FastAPI 的 API 方式调用
- [x] Web UI
  - [x] 基于 Streamlit 的 Web UI
---
## 项目交流群
<img src="img/qr_code_30.jpg" alt="二维码" width="300" height="300" />
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。


@@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 分布式部署时不运行LLM的机器上可以不装torch
def default_device():
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
except:
pass
return "cpu"
# 在以下字典中修改属性值以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
@@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto"
llm_model_dict = {
"chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b",
@@ -74,6 +86,16 @@ llm_model_dict = {
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
# 线上模型。当前支持智谱AI。
# 如果没有设置有效的local_model_path则认为是在线模型API。
# 请在server_config中为每个在线API设置不同的端口
# 具体注册及api key获取请前往 http://open.bigmodel.cn
"chatglm-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
"provider": "ChatGLMWorker",
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
},
}
# LLM 名称
@@ -82,7 +104,7 @@ LLM_MODEL = "chatglm2-6b"
# 历史对话轮数
HISTORY_LEN = 3
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一
LLM_DEVICE = "auto"
# 日志存储路径


@@ -1,4 +1,8 @@
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
import httpx
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
@@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
# 所有模型共用的默认配置可在模型专项配置或llm_model_dict中进行覆盖。
"default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
# 多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1"
# "num_gpus": 1, # 使用GPU的数量
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
# 以下为非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
},
"chatglm-api": { # 请为每个在线API设置不同的端口
"port": 20003,
},
}
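To make the override order concrete, here is a small, self-contained sketch of how a worker's effective config is assembled from `FSCHAT_MODEL_WORKERS["default"]`, `llm_model_dict`, and the per-model entries above; it mirrors `get_model_worker_config()` in `server/utils.py` later in this commit:

```python
from configs.model_config import llm_model_dict
from configs.server_config import FSCHAT_MODEL_WORKERS

def merged_worker_config(model_name: str) -> dict:
    # shared defaults first, then llm_model_dict, then the model-specific worker entry
    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(llm_model_dict.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
    return config

# e.g. "chatglm-api" keeps the default host/device but picks up "port": 20003
# and the ZHIPUAI_API_KEY-based settings from llm_model_dict.
print(merged_worker_config("chatglm-api"))
```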
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
# TODO:
}
# fastchat controller server

Binary image changes (contents not shown): img/qr_code_30.jpg added (269 KiB); seven existing binary image files deleted (200–292 KiB each, filenames not shown).


@@ -4,11 +4,12 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
@@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# LLM模型相关接口
@app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型")
def list_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
return app
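For orientation, a hedged example of calling the new model-management endpoints from a client. The host/port and model names are illustrative defaults (7861 is the API server's default port; the models must exist in `llm_model_dict` / `FSCHAT_MODEL_WORKERS`):

```python
import requests

API_BASE = "http://127.0.0.1:7861"  # adjust to your deployment

# list the models currently loaded by the fastchat controller
running = requests.post(f"{API_BASE}/llm_model/list_models").json()["data"]
print("running models:", running)

# ask the controller to swap one model worker for another
# (this can take minutes while the new worker loads)
resp = requests.post(
    f"{API_BASE}/llm_model/change",
    json={"model_name": "chatglm2-6b", "new_model_name": "chatglm-api"},
)
print(resp.json())
```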


@@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
@@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task
return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream")


@@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
@@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream")


@@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
@@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream")


@@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile


@@ -0,0 +1 @@
from .zhipu import ChatGLMWorker


@@ -0,0 +1,71 @@
from configs.model_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
error_code: int = 0
text: str
class ApiModelWorker(BaseModelWorker):
BASE_URL: str
SUPPORT_MODELS: List
def __init__(
self,
model_names: List[str],
controller_addr: str,
worker_addr: str,
context_len: int = 2048,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
self.init_heart_beat()
def count_token(self, params):
# TODO需要完善
print("count token")
print(params)
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params):
self.call_ct += 1
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
pass
return json.loads(x[:-1].decode())
def get_embeddings(self, params):
print("embedding")
print(params)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()


@@ -0,0 +1,75 @@
import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
def __init__(
self,
*,
model_names: List[str] = ["chatglm-api"],
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
# 这里的是chatglm api的模板其它API的conv_template需要定制
self.conv = conv.Conversation(
name="chatglm-api",
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["Human", "Assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 支持stream参数维护request_id传过来的prompt也有问题
from server.utils import get_model_worker_config
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
response = zhipuai.model_api.sse_invoke(
model=self.version,
prompt=[{"role": "user", "content": params["prompt"]}],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
)
for e in response.events():
if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
# TODO: 更健壮的消息处理
# elif e.event == "finish":
# ...
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)
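The zhipu worker above doubles as a template: other online APIs can be plugged in by subclassing `ApiModelWorker` and pointing an `llm_model_dict` entry's `provider` at the new class. A minimal, hypothetical sketch (the `EchoWorker` name and its echo behaviour are invented for illustration):

```python
import json
from typing import List
from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker

class EchoWorker(ApiModelWorker):
    # Hypothetical worker that simply echoes the prompt; a real worker would call its online API here.
    def __init__(self, *, model_names: List[str] = ["echo-api"],
                 controller_addr: str, worker_addr: str, **kwargs):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 2048)
        super().__init__(**kwargs)
        self.conv = conv.Conversation(name="echo-api", system_message="", messages=[],
                                      roles=["Human", "Assistant"], sep="\n### ", stop_str="###")

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        # fastchat expects NUL-terminated JSON chunks with error_code/text
        yield json.dumps({"error_code": 0, "text": params["prompt"]}, ensure_ascii=False).encode() + b"\0"
```

To wire such a worker in, the pattern used in this commit is to export it from `server/model_workers/__init__.py`, add an `llm_model_dict` entry with `"provider": "EchoWorker"`, and give it its own port in `FSCHAT_MODEL_WORKERS`.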


@@ -5,13 +5,17 @@ import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
from configs.server_config import FSCHAT_MODEL_WORKERS
import os
from server import model_workers
from typing import Literal, Optional, Any
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
data: Any = pydantic.Field(None, description="API data")
class Config:
schema_extra = {
@@ -201,10 +205,29 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
# 如果没有设置有效的local_model_path则认为是在线模型API
if not os.path.isdir(config.get("local_model_path", "")):
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
print(f"在线模型 {model_name} 的provider没有正确配置")
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
return config
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
return result
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER


@@ -1,6 +1,7 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
import asyncio
import sys
import os
from pprint import pprint
@@ -8,6 +9,7 @@ from pprint import pprint
# 设置numexpr最大线程数默认为CPU核心数
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
@@ -16,14 +18,14 @@ except:
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
@@ -41,6 +43,7 @@ def create_controller_app(
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
app._controller = controller
return app
@@ -97,43 +100,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# 在线模型API
if worker_class := kwargs.get("worker_class"):
    worker = worker_class(model_names=args.model_names,
                          controller_addr=args.controller_address,
                          worker_addr=args.worker_address)
# 本地模型
else:
    # workaround to make program exit with Ctrl+c
    # it should be deleted after pr is merged by fastchat
    def _new_init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()
    ModelWorker.init_heart_beat = _new_init_heart_beat

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        worker_id=worker_id,
        model_path=args.model_path,
        model_names=args.model_names,
        limit_worker_concurrency=args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
    )
    sys.modules["fastchat.serve.model_worker"].args = args
    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

sys.modules["fastchat.serve.model_worker"].worker = worker

MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app
@@ -188,8 +210,11 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
import uvicorn
import httpx
from fastapi import Body
import time
import sys
app = create_controller_app(
@@ -198,11 +223,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
)
_set_app_seq(app, q, run_seq)
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
logger.info(msg)
return {"code": 500, "msg": msg}
if new_model_name:
logger.info(f"开始切换LLM模型{model_name}{new_model_name}")
else:
logger.info(f"即将停止LLM模型 {model_name}")
if model_name not in available_models:
msg = f"the model {model_name} is not available"
logger.error(msg)
return {"code": 500, "msg": msg}
worker_address = app._controller.get_worker_address(model_name)
if not worker_address:
msg = f"can not find model_worker address for {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = 300 # wait 5 minutes for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
break
time.sleep(1)
timer -= 1
if timer > 0:
msg = f"sucess change model from {model_name} to {new_model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
else:
msg = f"failed change model from {model_name} to {new_model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
else:
msg = f"sucess to release model: {model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"] host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"] port = FSCHAT_CONTROLLER["port"]
if log_level == "ERROR": if log_level == "ERROR":
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -211,19 +296,20 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
log_level: str = "INFO",
):
import uvicorn
from fastapi import Body
import sys
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("local_model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
@@ -231,6 +317,22 @@ def run_model_worker(
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put(["start", new_model_name])
else:
if new_model_name:
q.put(["replace", new_model_name])
else:
q.put(["stop"])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -337,6 +439,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server",
dest="api",
)
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument(
"-w",
"--webui",
@@ -368,9 +477,14 @@ def dump_server_info(after_start=False, args=None):
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
model = LLM_MODEL
if args and args.model_name:
model = args.model_name
print(f"当前LLM模型{model} @ {llm_device()}")
pprint(llm_model_dict[model])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
@@ -385,10 +499,15 @@ def dump_server_info(after_start=False, args=None):
print("\n")
if __name__ == "__main__": async def start_main_server():
import time import time
mp.set_start_method("spawn") mp.set_start_method("spawn")
# TODO 链式启动的队列,确实可以用于控制启动顺序,
# 但目前引入proxy_worker后启动的独立于框架的work processes无法确认当前的位置
# 导致注册器未启动时,无法注册。整个启动链因为异常被终止
# 使用await asyncio.sleep(3)可以让后续代码等待一段时间,但不是最优解
queue = Queue() queue = Queue()
args, parser = parse_args() args, parser = parse_args()
@@ -396,17 +515,20 @@ if __name__ == "__main__":
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api_worker = True
args.api = False
args.webui = False
@@ -416,7 +538,11 @@ if __name__ == "__main__":
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {"online-api": []}
def process_count():
return len(processes) + len(processes["online-api"]) - 1
if args.quiet:
log_level = "ERROR"
else:
@@ -426,38 +552,52 @@ if __name__ == "__main__":
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
await asyncio.sleep(3)
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["openai_api"] = process
if args.model_worker:
config = get_model_worker_config(args.model_name)
if not config.get("online_api"):
process = Process(
target=run_model_worker,
name=f"model_worker - {args.model_name} ({os.getpid()})",
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["model_worker"] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
if config.get("online_api") and config.get("worker_class"):
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name} ({os.getpid()})",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["online-api"].append(process)
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
@@ -467,39 +607,53 @@ if __name__ == "__main__":
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["webui"] = process
if process_count() == 0:
parser.print_help()
else:
try:
# log infors
while True:
no = queue.get()
if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)
if model_worker_process := processes.pop("model_worker", None):
model_worker_process.join()
for process in processes.pop("online-api", []):
process.join()
for name, process in processes.items():
process.join()
except:
if model_worker_process := processes.pop("model_worker", None):
model_worker_process.terminate()
for process in processes.pop("online-api", []):
process.terminate()
for name, process in processes.items():
process.terminate()
if __name__ == "__main__":
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 同步调用协程代码
loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet


@@ -1,4 +1,3 @@
import requests
import json
import sys

tests/api/test_llm_api.py Normal file

@ -0,0 +1,74 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from pprint import pprint
import random
def get_configured_models():
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
llm_dict = list(llm_model_dict)
return model_workers, llm_dict
api_base_url = api_address()
def get_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
if r.status_code == 200:
return r.json()["data"]
return []
def test_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
assert r.status_code == 200
print("\n获取当前正在运行的模型列表:")
pprint(r.json())
assert isinstance(r.json()["data"], list)
assert len(r.json()["data"]) > 0
# 不建议使用stop_model功能。按现在的实现停止了就只能手动再启动
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change"):
url = api_base_url + api
running_models = get_running_models()
assert len(running_models) > 0
model_workers, llm_dict = get_configured_models()
availabel_new_models = set(model_workers) - set(running_models)
if len(availabel_new_models) == 0:
availabel_new_models = set(llm_dict) - set(running_models)
availabel_new_models = list(availabel_new_models)
assert len(availabel_new_models) > 0
print(availabel_new_models)
model_name = random.choice(running_models)
new_model_name = random.choice(availabel_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
assert r.status_code == 200
running_models = get_running_models()
assert new_model_name in running_models


@@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
}
def test_rapidocrpdfloader():
pdf_path = test_files["ocr_test.pdf"]
from document_loaders import RapidOCRPDFLoader
loader = RapidOCRPDFLoader(pdf_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)


@@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
}
def test_rapidocrloader():
img_path = test_files["ocr_test.jpg"]
from document_loaders import RapidOCRLoader
loader = RapidOCRLoader(img_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)


@@ -1,10 +1,14 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os
from configs.model_config import llm_model_dict, LLM_MODEL
from server.utils import get_model_worker_config
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model
def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x
running_models = api.list_running_models()
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
@@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]


@@ -6,12 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@@ -42,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@@ -289,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -301,6 +304,7 @@
"query": query,
"history": history,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
@@ -322,6 +326,7 @@
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -337,6 +342,7 @@
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"local_doc_url": no_remote_api,
}
@@ -361,6 +367,7 @@
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -374,6 +381,7 @@
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
@@ -645,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None):
'''
获取Fastchat中正运行的模型列表
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
def list_config_models(self):
'''
获取configs中配置的模型列表
'''
return list(llm_model_dict.keys())
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
'''
if not model_name or not new_model_name:
return
if new_model_name == model_name:
return {
"code": 200,
"msg": "什么都不用做"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models:
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
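A hedged example of driving these new helpers from Python, much as the Streamlit dialogue page does; it assumes this module is `webui_pages/utils.py` and uses illustrative model names:

```python
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # default api.py address

print(api.list_running_models())   # models currently served via fastchat
print(api.list_config_models())    # models declared in llm_model_dict

# swap the running worker for another configured model;
# this can take a while and relies on HTTPX_DEFAULT_TIMEOUT
print(api.change_llm_model("chatglm2-6b", "chatglm-api"))
```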
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''