diff --git a/README.md b/README.md
index c5c2d687..e79a30b1 100644
--- a/README.md
+++ b/README.md
@@ -216,12 +216,13 @@ embedding_model_dict = {
当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
-
+
```shell
$ python init_database.py
```
-- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启 `normalize_L2`,需要以下命令初始化或重建知识库:
+- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启 `normalize_L2`,需要以下命令初始化或重建知识库:
+
```shell
$ python init_database.py --recreate-vs
```
@@ -266,7 +267,7 @@ max_gpu_memory="20GiB"
⚠️ **注意:**
-**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;**
+**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;**
**2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;**
@@ -324,7 +325,7 @@ $ python server/api.py
启动 API 服务后,可访问 `localhost:7861` 或 `{API 所在服务器 IP}:7861` FastAPI 自动生成的 docs 进行接口查看与测试。
- FastAPI docs 界面
-
+

#### 5.3 启动 Web UI 服务
@@ -348,10 +349,11 @@ $ streamlit run webui.py --server.port 666
```
- Web UI 对话界面:
-
+

-- Web UI 知识库管理页面:
+- Web UI 知识库管理页面:
+

---
@@ -399,14 +401,14 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## 路线图
-- [X] Langchain 应用
- - [X] 本地数据接入
- - [X] 接入非结构化文档
- - [X] .md
- - [X] .txt
- - [X] .docx
+- [x] Langchain 应用
+ - [x] 本地数据接入
+ - [x] 接入非结构化文档
+ - [x] .md
+ - [x] .txt
+ - [x] .docx
- [ ] 结构化数据接入
- - [X] .csv
+ - [x] .csv
- [ ] .xlsx
- [ ] 分词及召回
- [ ] 接入不同类型 TextSplitter
@@ -415,24 +417,24 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
- [ ] 本地网页接入
- [ ] SQL 接入
- [ ] 知识图谱/图数据库接入
- - [X] 搜索引擎接入
- - [X] Bing 搜索
- - [X] DuckDuckGo 搜索
+ - [x] 搜索引擎接入
+ - [x] Bing 搜索
+ - [x] DuckDuckGo 搜索
- [ ] Agent 实现
-- [X] LLM 模型接入
- - [X] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
+- [x] LLM 模型接入
+  - [x] 支持通过 [FastChat](https://github.com/lm-sys/fastchat) API 调用 LLM
- [ ] 支持 ChatGLM API 等 LLM API 的接入
-- [X] Embedding 模型接入
- - [X] 支持调用 HuggingFace 中各开源 Emebdding 模型
+- [x] Embedding 模型接入
+  - [x] 支持调用 HuggingFace 中各开源 Embedding 模型
- [ ] 支持 OpenAI Embedding API 等 Embedding API 的接入
-- [X] 基于 FastAPI 的 API 方式调用
-- [X] Web UI
- - [X] 基于 Streamlit 的 Web UI
+- [x] 基于 FastAPI 的 API 方式调用
+- [x] Web UI
+ - [x] 基于 Streamlit 的 Web UI
---
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 53097db3..1b54e837 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
+# 分布式部署时,不运行LLM的机器上可以不装torch
+def default_device():
+ try:
+ import torch
+ if torch.cuda.is_available():
+ return "cuda"
+ if torch.backends.mps.is_available():
+ return "mps"
+    except Exception:
+ pass
+ return "cpu"
+
+
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
@@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto"
-
llm_model_dict = {
"chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b",
@@ -74,6 +86,16 @@ llm_model_dict = {
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
+ # 线上模型。当前支持智谱AI。
+ # 如果没有设置有效的local_model_path,则认为是在线模型API。
+ # 请在server_config中为每个在线API设置不同的端口
+ # 具体注册及api key获取请前往 http://open.bigmodel.cn
+ "chatglm-api": {
+ "api_base_url": "http://127.0.0.1:8888/v1",
+ "api_key": os.environ.get("ZHIPUAI_API_KEY"),
+ "provider": "ChatGLMWorker",
+ "version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
+ },
}
# LLM 名称
@@ -82,7 +104,7 @@ LLM_MODEL = "chatglm2-6b"
# 历史对话轮数
HISTORY_LEN = 3
-# LLM 运行设备。可选项同Embedding 运行设备。
+# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
LLM_DEVICE = "auto"
# 日志存储路径
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 00f94ea5..ad731e37 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -1,4 +1,8 @@
-from .model_config import LLM_MODEL, LLM_DEVICE
+from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+import httpx
+
+# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
+HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
# is open cross domain
@@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
- LLM_MODEL: {
+ # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+ "default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
- # todo: 多卡加载需要配置的参数
+
+ # 多卡加载需要配置的参数
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
# "num_gpus": 1, # 使用GPU的数量
- # 以下为非常用参数,可根据需要配置
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
+
+ # 以下为非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
+ "baichuan-7b": { # 使用default中的IP和端口
+ "device": "cpu",
+ },
+ "chatglm-api": { # 请为每个在线API设置不同的端口
+ "port": 20003,
+ },
}
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
- # todo
+ # TODO:
}
# fastchat controller server
diff --git a/img/qr_code_30.jpg b/img/qr_code_30.jpg
new file mode 100644
index 00000000..05d26467
Binary files /dev/null and b/img/qr_code_30.jpg differ
diff --git a/img/qr_code_50.jpg b/img/qr_code_50.jpg
deleted file mode 100644
index c0ae20f6..00000000
Binary files a/img/qr_code_50.jpg and /dev/null differ
diff --git a/img/qr_code_51.jpg b/img/qr_code_51.jpg
deleted file mode 100644
index f0993322..00000000
Binary files a/img/qr_code_51.jpg and /dev/null differ
diff --git a/img/qr_code_52.jpg b/img/qr_code_52.jpg
deleted file mode 100644
index 18793d56..00000000
Binary files a/img/qr_code_52.jpg and /dev/null differ
diff --git a/img/qr_code_53.jpg b/img/qr_code_53.jpg
deleted file mode 100644
index 3174ccc1..00000000
Binary files a/img/qr_code_53.jpg and /dev/null differ
diff --git a/img/qr_code_54.jpg b/img/qr_code_54.jpg
deleted file mode 100644
index 1245a164..00000000
Binary files a/img/qr_code_54.jpg and /dev/null differ
diff --git a/img/qr_code_55.jpg b/img/qr_code_55.jpg
deleted file mode 100644
index 8ff046c9..00000000
Binary files a/img/qr_code_55.jpg and /dev/null differ
diff --git a/img/qr_code_56.jpg b/img/qr_code_56.jpg
deleted file mode 100644
index f17458d2..00000000
Binary files a/img/qr_code_56.jpg and /dev/null differ
diff --git a/server/api.py b/server/api.py
index fe5e156e..37954b7f 100644
--- a/server/api.py
+++ b/server/api.py
@@ -4,11 +4,12 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import NLTK_DATA_PATH
-from configs.server_config import OPEN_CROSS_DOMAIN
+from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
import argparse
import uvicorn
+from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
@@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
+import httpx
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库,流式输出处理进度。"
)(recreate_vector_store)
+ # LLM模型相关接口
+ @app.post("/llm_model/list_models",
+ tags=["LLM Model Management"],
+ summary="列出当前已加载的模型")
+ def list_models(
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ) -> BaseResponse:
+ '''
+ 从fastchat controller获取已加载模型列表
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
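+            # FastChat controller 的 /list_models 接口返回 {"models": [...]}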
+ r = httpx.post(controller_address + "/list_models")
+ return BaseResponse(data=r.json()["models"])
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ data=[],
+ msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
+
+ @app.post("/llm_model/stop",
+ tags=["LLM Model Management"],
+ summary="停止指定的LLM模型(Model Worker)",
+ )
+ def stop_llm_model(
+ model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ) -> BaseResponse:
+ '''
+ 向fastchat controller请求停止某个LLM模型。
+ 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
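+            # /release_worker 是 startup.py 中为 controller 附加的自定义接口,
+            # 不传 new_model_name 时仅停止该模型对应的 model_worker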
+ r = httpx.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name},
+ )
+ return r.json()
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
+
+ @app.post("/llm_model/change",
+ tags=["LLM Model Management"],
+ summary="切换指定的LLM模型(Model Worker)",
+ )
+ def change_llm_model(
+ model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
+ new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ):
+ '''
+ 向fastchat controller请求切换LLM模型。
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
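+            # 同样调用 /release_worker,带上 new_model_name 表示释放旧模型后等待新模型注册完成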
+ r = httpx.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name, "new_model_name": new_model_name},
+ timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+ )
+ return r.json()
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
+
return app
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 2e939f1d..ba23a5a1 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
@@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task
- return StreamingResponse(chat_iterator(query, history),
+ return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 27745691..69ec25dd 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService,
top_k: int,
history: Optional[List[History]],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
@@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task
- return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
+ return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream")
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 8a2633bd..8fe7dae6 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
@@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False)
await task
- return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
+ return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 8e05b426..afe9f450 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
-from configs.model_config import kbs_config
+from configs.model_config import EMBEDDING_DEVICE, kbs_config
+
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
new file mode 100644
index 00000000..932c2f3b
--- /dev/null
+++ b/server/model_workers/__init__.py
@@ -0,0 +1 @@
+from .zhipu import ChatGLMWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
new file mode 100644
index 00000000..b72f6839
--- /dev/null
+++ b/server/model_workers/base.py
@@ -0,0 +1,71 @@
+from configs.model_config import LOG_PATH
+import fastchat.constants
+fastchat.constants.LOGDIR = LOG_PATH
+from fastchat.serve.model_worker import BaseModelWorker
+import uuid
+import json
+import sys
+from pydantic import BaseModel
+import fastchat
+import threading
+from typing import Dict, List
+
+
+# 恢复被fastchat覆盖的标准输出
+sys.stdout = sys.__stdout__
+sys.stderr = sys.__stderr__
+
+
+class ApiModelOutMsg(BaseModel):
+ error_code: int = 0
+ text: str
+
+class ApiModelWorker(BaseModelWorker):
+ BASE_URL: str
+ SUPPORT_MODELS: List
+
+ def __init__(
+ self,
+ model_names: List[str],
+ controller_addr: str,
+ worker_addr: str,
+ context_len: int = 2048,
+ **kwargs,
+ ):
+ kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
+ kwargs.setdefault("model_path", "")
+ kwargs.setdefault("limit_worker_concurrency", 5)
+ super().__init__(model_names=model_names,
+ controller_addr=controller_addr,
+ worker_addr=worker_addr,
+ **kwargs)
+ self.context_len = context_len
+ self.init_heart_beat()
+
+ def count_token(self, params):
+ # TODO:需要完善
+ print("count token")
+ print(params)
+ prompt = params["prompt"]
+ return {"count": len(str(prompt)), "error_code": 0}
+
+ def generate_stream_gate(self, params):
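+        # 子类应将本方法重写为生成器,并调用 super().generate_stream_gate(params) 以累计 call_ct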
+ self.call_ct += 1
+
+ def generate_gate(self, params):
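+        # 消费流式输出并返回最后一条消息;fastchat 约定每条消息以 b"\0" 结尾,解析前需去掉末尾字节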
+ for x in self.generate_stream_gate(params):
+ pass
+ return json.loads(x[:-1].decode())
+
+ def get_embeddings(self, params):
+ print("embedding")
+ print(params)
+
+ # workaround to make program exit with Ctrl+c
+ # it should be deleted after pr is merged by fastchat
+ def init_heart_beat(self):
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+ )
+ self.heart_beat_thread.start()
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
new file mode 100644
index 00000000..4e4e15e0
--- /dev/null
+++ b/server/model_workers/zhipu.py
@@ -0,0 +1,75 @@
+import zhipuai
+from server.model_workers.base import ApiModelWorker
+from fastchat import conversation as conv
+import sys
+import json
+from typing import List, Literal
+
+
+class ChatGLMWorker(ApiModelWorker):
+ BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
+ SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
+
+ def __init__(
+ self,
+ *,
+ model_names: List[str] = ["chatglm-api"],
+ version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
+ controller_addr: str,
+ worker_addr: str,
+ **kwargs,
+ ):
+ kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+ kwargs.setdefault("context_len", 32768)
+ super().__init__(**kwargs)
+ self.version = version
+
+ # 这里的是chatglm api的模板,其它API的conv_template需要定制
+ self.conv = conv.Conversation(
+ name="chatglm-api",
+ system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+ messages=[],
+ roles=["Human", "Assistant"],
+ sep="\n### ",
+ stop_str="###",
+ )
+
+ def generate_stream_gate(self, params):
+ # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
+ from server.utils import get_model_worker_config
+
+ super().generate_stream_gate(params)
+ zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
+
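+        # zhipuai SDK 的 sse_invoke 返回 SSE 事件流,事件类型包括 add、finish、error 等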
+ response = zhipuai.model_api.sse_invoke(
+ model=self.version,
+ prompt=[{"role": "user", "content": params["prompt"]}],
+ temperature=params.get("temperature"),
+ top_p=params.get("top_p"),
+ incremental=False,
+ )
+ for e in response.events():
+ if e.event == "add":
+ yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
+ # TODO: 更健壮的消息处理
+ # elif e.event == "finish":
+ # ...
+
+ def get_embeddings(self, params):
+ # TODO: 支持embeddings
+ print("embedding")
+ print(params)
+
+
+if __name__ == "__main__":
+ import uvicorn
+ from server.utils import MakeFastAPIOffline
+ from fastchat.serve.model_worker import app
+
+ worker = ChatGLMWorker(
+ controller_addr="http://127.0.0.1:20001",
+ worker_addr="http://127.0.0.1:20003",
+ )
+ sys.modules["fastchat.serve.model_worker"].worker = worker
+ MakeFastAPIOffline(app)
+ uvicorn.run(app, port=20003)
diff --git a/server/utils.py b/server/utils.py
index 0f9d2df0..0e53e3df 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -4,13 +4,17 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
-from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
-from typing import Literal, Optional
+from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
+from configs.server_config import FSCHAT_MODEL_WORKERS
+import os
+from server import model_workers
+from typing import Literal, Optional, Any
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
+ data: Any = pydantic.Field(None, description="API data")
class Config:
schema_extra = {
@@ -201,10 +205,29 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
- config["device"] = llm_device(config.get("device"))
+
+ # 如果没有设置有效的local_model_path,则认为是在线模型API
+ if not os.path.isdir(config.get("local_model_path", "")):
+ config["online_api"] = True
+ if provider := config.get("provider"):
+ try:
+ config["worker_class"] = getattr(model_workers, provider)
+ except Exception as e:
+                print(f"在线模型 '{model_name}' 的 provider 没有正确配置:{e}")
+
+ config["device"] = llm_device(config.get("device") or LLM_DEVICE)
return config
+def get_all_model_worker_configs() -> dict:
+ result = {}
+ model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+ for name in model_names:
+ if name != "default":
+ result[name] = get_model_worker_config(name)
+ return result
+
+
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER
diff --git a/startup.py b/startup.py
index 3a21010e..ecb722c0 100644
--- a/startup.py
+++ b/startup.py
@@ -1,6 +1,7 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
+import asyncio
import sys
import os
from pprint import pprint
@@ -8,6 +9,7 @@ from pprint import pprint
# 设置numexpr最大线程数,默认为CPU核心数
try:
import numexpr
+
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
@@ -16,14 +18,14 @@ except:
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
-from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
+from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
- fschat_openai_api_address, set_httpx_timeout,
- llm_device, embedding_device, get_model_worker_config,
- MakeFastAPIOffline, FastAPI)
+ fschat_openai_api_address, set_httpx_timeout,
+ get_model_worker_config, get_all_model_worker_configs,
+ MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
-from typing import Tuple, List
+from typing import Tuple, List, Dict
from configs import VERSION
@@ -41,6 +43,7 @@ def create_controller_app(
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
+ app._controller = controller
return app
@@ -97,43 +100,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
- gptq_config = GptqConfig(
- ckpt=args.gptq_ckpt or args.model_path,
- wbits=args.gptq_wbits,
- groupsize=args.gptq_groupsize,
- act_order=args.gptq_act_order,
- )
- awq_config = AWQConfig(
- ckpt=args.awq_ckpt or args.model_path,
- wbits=args.awq_wbits,
- groupsize=args.awq_groupsize,
- )
+ # 在线模型API
+ if worker_class := kwargs.get("worker_class"):
+ worker = worker_class(model_names=args.model_names,
+ controller_addr=args.controller_address,
+ worker_addr=args.worker_address)
+ # 本地模型
+ else:
+ # workaround to make program exit with Ctrl+c
+ # it should be deleted after pr is merged by fastchat
+ def _new_init_heart_beat(self):
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+ )
+ self.heart_beat_thread.start()
- worker = ModelWorker(
- controller_addr=args.controller_address,
- worker_addr=args.worker_address,
- worker_id=worker_id,
- model_path=args.model_path,
- model_names=args.model_names,
- limit_worker_concurrency=args.limit_worker_concurrency,
- no_register=args.no_register,
- device=args.device,
- num_gpus=args.num_gpus,
- max_gpu_memory=args.max_gpu_memory,
- load_8bit=args.load_8bit,
- cpu_offloading=args.cpu_offloading,
- gptq_config=gptq_config,
- awq_config=awq_config,
- stream_interval=args.stream_interval,
- conv_template=args.conv_template,
- )
+ ModelWorker.init_heart_beat = _new_init_heart_beat
+
+ gptq_config = GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ )
+ awq_config = AWQConfig(
+ ckpt=args.awq_ckpt or args.model_path,
+ wbits=args.awq_wbits,
+ groupsize=args.awq_groupsize,
+ )
+
+ worker = ModelWorker(
+ controller_addr=args.controller_address,
+ worker_addr=args.worker_address,
+ worker_id=worker_id,
+ model_path=args.model_path,
+ model_names=args.model_names,
+ limit_worker_concurrency=args.limit_worker_concurrency,
+ no_register=args.no_register,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ stream_interval=args.stream_interval,
+ conv_template=args.conv_template,
+ )
+ sys.modules["fastchat.serve.model_worker"].args = args
+ sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
- sys.modules["fastchat.serve.model_worker"].args = args
- sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
- app.title = f"FastChat LLM Server ({LLM_MODEL})"
+ app.title = f"FastChat LLM Server ({args.model_names[0]})"
+ app._worker = worker
return app
@@ -188,8 +210,11 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
-def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO"):
import uvicorn
+ import httpx
+ from fastapi import Body
+ import time
import sys
app = create_controller_app(
@@ -198,11 +223,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
)
_set_app_seq(app, q, run_seq)
+ # add interface to release and load model worker
+ @app.post("/release_worker")
+ def release_worker(
+        model_name: str = Body(..., description="要释放模型的名称", examples=["chatglm-6b"]),
+        # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", examples=[fschat_controller_address()]),
+ new_model_name: str = Body(None, description="释放后加载该模型"),
+ keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+ ) -> Dict:
+ available_models = app._controller.list_models()
+ if new_model_name in available_models:
+ msg = f"要切换的LLM模型 {new_model_name} 已经存在"
+ logger.info(msg)
+ return {"code": 500, "msg": msg}
+
+ if new_model_name:
+ logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
+ else:
+ logger.info(f"即将停止LLM模型: {model_name}")
+
+ if model_name not in available_models:
+ msg = f"the model {model_name} is not available"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
+ worker_address = app._controller.get_worker_address(model_name)
+ if not worker_address:
+ msg = f"can not find model_worker address for {model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
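+        # 调用 model_worker 进程上的 /release 接口(见 run_model_worker 中的定义)通知其释放或切换模型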
+ r = httpx.post(worker_address + "/release",
+ json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+ if r.status_code != 200:
+ msg = f"failed to release model: {model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
+ if new_model_name:
+ timer = 300 # wait 5 minutes for new model_worker register
+ while timer > 0:
+ models = app._controller.list_models()
+ if new_model_name in models:
+ break
+ time.sleep(1)
+ timer -= 1
+ if timer > 0:
+                msg = f"successfully changed model from {model_name} to {new_model_name}"
+ logger.info(msg)
+ return {"code": 200, "msg": msg}
+ else:
+                msg = f"failed to change model from {model_name} to {new_model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+ else:
+            msg = f"successfully released model: {model_name}"
+ logger.info(msg)
+ return {"code": 200, "msg": msg}
+
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
+
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
+
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -211,19 +296,20 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
- log_level: str ="INFO",
+ log_level: str = "INFO",
):
import uvicorn
+ from fastapi import Body
import sys
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
- model_path = llm_model_dict[model_name].get("local_model_path", "")
- kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
- kwargs["worker_address"] = fschat_model_worker_address()
+ kwargs["worker_address"] = fschat_model_worker_address(model_name)
+ model_path = kwargs.get("local_model_path", "")
+ kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
@@ -231,6 +317,22 @@ def run_model_worker(
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
+ # add interface to release and load model
+ @app.post("/release")
+ def release_model(
+ new_model_name: str = Body(None, description="释放后加载该模型"),
+ keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+ ) -> Dict:
+ if keep_origin:
+ if new_model_name:
+ q.put(["start", new_model_name])
+ else:
+ if new_model_name:
+ q.put(["replace", new_model_name])
+ else:
+ q.put(["stop"])
+ return {"code": 200, "msg": "done"}
+
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
@@ -337,6 +439,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server",
dest="api",
)
+ parser.add_argument(
+ "-p",
+ "--api-worker",
+ action="store_true",
+ help="run online model api such as zhipuai",
+ dest="api_worker",
+ )
parser.add_argument(
"-w",
"--webui",
@@ -368,9 +477,14 @@ def dump_server_info(after_start=False, args=None):
print(f"项目版本:{VERSION}")
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
print("\n")
- print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
- pprint(llm_model_dict[LLM_MODEL])
+
+ model = LLM_MODEL
+ if args and args.model_name:
+ model = args.model_name
+ print(f"当前LLM模型:{model} @ {llm_device()}")
+ pprint(llm_model_dict[model])
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
+
if after_start:
print("\n")
print(f"服务端运行信息:")
@@ -385,10 +499,15 @@ def dump_server_info(after_start=False, args=None):
print("\n")
-if __name__ == "__main__":
+async def start_main_server():
import time
mp.set_start_method("spawn")
+ # TODO 链式启动的队列,确实可以用于控制启动顺序,
+ # 但目前引入proxy_worker后,启动的独立于框架的work processes无法确认当前的位置,
+ # 导致注册器未启动时,无法注册。整个启动链因为异常被终止
+ # 使用await asyncio.sleep(3)可以让后续代码等待一段时间,但不是最优解
+
queue = Queue()
args, parser = parse_args()
@@ -396,17 +515,20 @@ if __name__ == "__main__":
args.openai_api = True
args.model_worker = True
args.api = True
+ args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
+ args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
+ args.api_worker = True
args.api = False
args.webui = False
@@ -416,7 +538,11 @@ if __name__ == "__main__":
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
- processes = {}
+ processes = {"online-api": []}
+
+ def process_count():
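+        # processes 中 "online-api" 对应一个进程列表,其余键各对应一个进程;减 1 是去掉 "online-api" 这个键本身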
+ return len(processes) + len(processes["online-api"]) - 1
+
if args.quiet:
log_level = "ERROR"
else:
@@ -426,38 +552,52 @@ if __name__ == "__main__":
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
- args=(queue, len(processes) + 1, log_level),
+ args=(queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
+ await asyncio.sleep(3)
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
- args=(queue, len(processes) + 1),
+ args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["openai_api"] = process
if args.model_worker:
- model_path = llm_model_dict[args.model_name].get("local_model_path", "")
- if os.path.isdir(model_path):
+ config = get_model_worker_config(args.model_name)
+ if not config.get("online_api"):
process = Process(
target=run_model_worker,
- name=f"model_worker({os.getpid()})",
- args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level),
+ name=f"model_worker - {args.model_name} ({os.getpid()})",
+ args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["model_worker"] = process
+ if args.api_worker:
+ configs = get_all_model_worker_configs()
+ for model_name, config in configs.items():
+ if config.get("online_api") and config.get("worker_class"):
+ process = Process(
+ target=run_model_worker,
+ name=f"model_worker - {model_name} ({os.getpid()})",
+ args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+ daemon=True,
+ )
+ process.start()
+ processes["online-api"].append(process)
+
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
- args=(queue, len(processes) + 1),
+ args=(queue, process_count() + 1),
daemon=True,
)
process.start()
@@ -467,39 +607,53 @@ if __name__ == "__main__":
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
- args=(queue, len(processes) + 1),
+ args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["webui"] = process
- if len(processes) == 0:
+ if process_count() == 0:
parser.print_help()
else:
try:
- # log infors
while True:
no = queue.get()
- if no == len(processes):
+ if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)
- if model_worker_process := processes.get("model_worker"):
+ if model_worker_process := processes.pop("model_worker", None):
model_worker_process.join()
+ for process in processes.pop("online-api", []):
+ process.join()
for name, process in processes.items():
- if name != "model_worker":
- process.join()
+ process.join()
except:
- if model_worker_process := processes.get("model_worker"):
+ if model_worker_process := processes.pop("model_worker", None):
model_worker_process.terminate()
+ for process in processes.pop("online-api", []):
+ process.terminate()
for name, process in processes.items():
- if name != "model_worker":
- process.terminate()
+ process.terminate()
+if __name__ == "__main__":
+
+ if sys.version_info < (3, 10):
+ loop = asyncio.get_event_loop()
+ else:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+
+ asyncio.set_event_loop(loop)
+ # 同步调用协程代码
+ loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py
new file mode 100644
index 00000000..f348fe74
--- /dev/null
+++ b/tests/api/test_llm_api.py
@@ -0,0 +1,74 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
+from configs.model_config import LLM_MODEL, llm_model_dict
+
+from pprint import pprint
+import random
+
+
+def get_configured_models():
+ model_workers = list(FSCHAT_MODEL_WORKERS)
+ if "default" in model_workers:
+ model_workers.remove("default")
+
+ llm_dict = list(llm_model_dict)
+
+ return model_workers, llm_dict
+
+
+api_base_url = api_address()
+
+
+def get_running_models(api="/llm_model/list_models"):
+ url = api_base_url + api
+ r = requests.post(url)
+ if r.status_code == 200:
+ return r.json()["data"]
+ return []
+
+
+def test_running_models(api="/llm_model/list_models"):
+ url = api_base_url + api
+ r = requests.post(url)
+ assert r.status_code == 200
+ print("\n获取当前正在运行的模型列表:")
+ pprint(r.json())
+ assert isinstance(r.json()["data"], list)
+ assert len(r.json()["data"]) > 0
+
+
+# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
+# def test_stop_model(api="/llm_model/stop"):
+# url = api_base_url + api
+# r = requests.post(url, json={""})
+
+
+def test_change_model(api="/llm_model/change"):
+ url = api_base_url + api
+
+ running_models = get_running_models()
+ assert len(running_models) > 0
+
+ model_workers, llm_dict = get_configured_models()
+
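+    # 优先从 FSCHAT_MODEL_WORKERS 中挑选未运行的模型,若没有则再从 llm_model_dict 中挑选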
+    available_new_models = set(model_workers) - set(running_models)
+    if len(available_new_models) == 0:
+        available_new_models = set(llm_dict) - set(running_models)
+    available_new_models = list(available_new_models)
+    assert len(available_new_models) > 0
+    print(available_new_models)
+
+ model_name = random.choice(running_models)
+    new_model_name = random.choice(available_new_models)
+ print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
+ r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
+ assert r.status_code == 200
+
+ running_models = get_running_models()
+ assert new_model_name in running_models
diff --git a/tests/document_loader/test_imgloader.py b/tests/document_loader/test_imgloader.py
new file mode 100644
index 00000000..8bba7da9
--- /dev/null
+++ b/tests/document_loader/test_imgloader.py
@@ -0,0 +1,21 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from pprint import pprint
+
+test_files = {
+ "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
+}
+
+def test_rapidocrpdfloader():
+ pdf_path = test_files["ocr_test.pdf"]
+ from document_loaders import RapidOCRPDFLoader
+
+ loader = RapidOCRPDFLoader(pdf_path)
+ docs = loader.load()
+ pprint(docs)
+ assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
+
+
diff --git a/tests/document_loader/test_pdfloader.py b/tests/document_loader/test_pdfloader.py
new file mode 100644
index 00000000..92460cb4
--- /dev/null
+++ b/tests/document_loader/test_pdfloader.py
@@ -0,0 +1,21 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from pprint import pprint
+
+test_files = {
+ "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
+}
+
+def test_rapidocrloader():
+ img_path = test_files["ocr_test.jpg"]
+ from document_loaders import RapidOCRLoader
+
+ loader = RapidOCRLoader(img_path)
+ docs = loader.load()
+ pprint(docs)
+ assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
+
+
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 2d5e260b..730b6420 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,10 +1,14 @@
import streamlit as st
+from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
-from typing import List, Dict
import os
+from configs.model_config import llm_model_dict, LLM_MODEL
+from server.utils import get_model_worker_config
+from typing import List, Dict
+
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
+
+ def on_llm_change():
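+            # 记录切换前选中的模型,便于后续判断是否需要真正加载新模型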
+ st.session_state["prev_llm_model"] = llm_model
+
+ def llm_model_format_func(x):
+ if x in running_models:
+ return f"{x} (Running)"
+ return x
+
+ running_models = api.list_running_models()
+ config_models = api.list_config_models()
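+        # 合并已运行模型与 configs 中配置的模型并去重,运行中的模型排在下拉列表前面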
+ for x in running_models:
+ if x in config_models:
+ config_models.remove(x)
+ llm_models = running_models + config_models
+ if "prev_llm_model" not in st.session_state:
+ index = llm_models.index(LLM_MODEL)
+ else:
+ index = 0
+ llm_model = st.selectbox("选择LLM模型:",
+ llm_models,
+ index,
+ format_func=llm_model_format_func,
+ on_change=on_llm_change,
+ # key="llm_model",
+ )
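+        # 当选择的模型发生变化且新模型不是在线 API 时,调用后端接口切换本地模型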
+ if (st.session_state.get("prev_llm_model") != llm_model
+ and not get_model_worker_config(llm_model).get("online_api")):
+ with st.spinner(f"正在加载模型: {llm_model}"):
+ r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
+ st.session_state["prev_llm_model"] = llm_model
+
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
@@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
- r = api.chat_chat(prompt, history)
+ r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
- for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
+ for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
- for d in api.search_engine_chat(prompt, search_engine, se_top_k):
- if error_msg := check_error_msg(d): # check whether error occured
+ for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+ if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 58b08e87..08511044 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -6,12 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
+ llm_model_dict,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
+from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@@ -42,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
- timeout: float = 60.0,
+ timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@@ -289,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -301,6 +304,7 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
+ "model_name": model,
}
print(f"received input message:")
@@ -322,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -337,6 +342,7 @@ class ApiRequest:
"score_threshold": score_threshold,
"history": history,
"stream": stream,
+ "model_name": model,
"local_doc_url": no_remote_api,
}
@@ -361,6 +367,7 @@ class ApiRequest:
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -374,6 +381,7 @@ class ApiRequest:
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
+ "model_name": model,
}
print(f"received input message:")
@@ -645,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
+ def list_running_models(self, controller_address: str = None):
+ '''
+ 获取Fastchat中正运行的模型列表
+ '''
+ r = self.post(
+ "/llm_model/list_models",
+ )
+ return r.json().get("data", [])
+
+ def list_config_models(self):
+ '''
+ 获取configs中配置的模型列表
+ '''
+ return list(llm_model_dict.keys())
+
+ def stop_llm_model(
+ self,
+ model_name: str,
+ controller_address: str = None,
+ ):
+ '''
+ 停止某个LLM模型。
+ 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+ '''
+ data = {
+ "model_name": model_name,
+ "controller_address": controller_address,
+ }
+ r = self.post(
+ "/llm_model/stop",
+ json=data,
+ )
+ return r.json()
+
+ def change_llm_model(
+ self,
+ model_name: str,
+ new_model_name: str,
+ controller_address: str = None,
+ ):
+ '''
+ 向fastchat controller请求切换LLM模型。
+ '''
+ if not model_name or not new_model_name:
+ return
+
+ if new_model_name == model_name:
+ return {
+ "code": 200,
+ "msg": "什么都不用做"
+ }
+
+ running_models = self.list_running_models()
+ if model_name not in running_models:
+ return {
+ "code": 500,
+ "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+ }
+
+ config_models = self.list_config_models()
+ if new_model_name not in config_models:
+ return {
+ "code": 500,
+ "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+ }
+
+ data = {
+ "model_name": model_name,
+ "new_model_name": new_model_name,
+ "controller_address": controller_address,
+ }
+ r = self.post(
+ "/llm_model/change",
+ json=data,
+ timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+ )
+ return r.json()
+
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''