Add a model-switching feature and support Zhipu AI online models (#1342)

* Add LLM model switching; the models that can be switched to must be configured in server_config
* Add tests for the api.py /llm_model/* endpoints
* - Support switching between models
  - Support Zhipu AI online models
  - Add a `--api-worker` flag to startup.py that automatically runs all configured online API models; it is enabled by default when using `-a` (`--all-webui`) or `--all-api`
* Restore the standard output streams that fastchat overrides
* Give finer-grained control over fastchat logging: add a `-q` (`--quiet`) switch to startup.py to reduce useless fastchat log output
* Fix the conversation template for the chatglm API


Co-authored-by: liunux4odoo <liunu@qq.com>
liunux4odoo 2023-09-01 23:58:09 +08:00 committed by GitHub
parent ab4c8d2e5d
commit 6cb1bdf623
16 changed files with 703 additions and 94 deletions

View File

@ -7,6 +7,19 @@ logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT) logging.basicConfig(format=LOG_FORMAT)
# 分布式部署时不运行LLM的机器上可以不装torch
def default_device():
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except:
        pass
    return "cpu"
# 在以下字典中修改属性值以指定本地embedding模型存储位置 # 在以下字典中修改属性值以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese" # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径 # 此处请写绝对路径
@ -34,7 +47,6 @@ EMBEDDING_MODEL = "m3e-base"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto" EMBEDDING_DEVICE = "auto"
llm_model_dict = { llm_model_dict = {
"chatglm-6b": { "chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b", "local_model_path": "THUDM/chatglm-6b",
@ -74,6 +86,15 @@ llm_model_dict = {
"api_key": os.environ.get("OPENAI_API_KEY"), "api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY") "openai_proxy": os.environ.get("OPENAI_PROXY")
}, },
# 线上模型。当前支持智谱AI。
# 如果没有设置有效的local_model_path则认为是在线模型API。
# 请在server_config中为每个在线API设置不同的端口
"chatglm-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
"provider": "ChatGLMWorker",
"version": "chatglm_pro",
},
} }
# LLM 名称 # LLM 名称
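As an illustration only (the entry name "chatglm-api-lite" and port 8889 are made up), a second Zhipu AI entry could reuse the same pattern with a different version; each online API entry also needs its own worker port in server_config:

import os

llm_model_dict_example = {
    "chatglm-api-lite": {
        "api_base_url": "http://127.0.0.1:8889/v1",     # hypothetical port
        "api_key": os.environ.get("ZHIPUAI_API_KEY"),
        "provider": "ChatGLMWorker",
        "version": "chatglm_lite",                      # one of ChatGLMWorker.SUPPORT_MODELS
    },
}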

View File

@ -1,4 +1,8 @@
from .model_config import LLM_MODEL, LLM_DEVICE from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
import httpx
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域默认为False如果需要开启请设置为True # API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain # is open cross domain
@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# 这些模型必须是在model_config.llm_model_dict中正确配置的。 # 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL # 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = { FSCHAT_MODEL_WORKERS = {
LLM_MODEL: { # 所有模型共用的默认配置可在模型专项配置或llm_model_dict中进行覆盖。
"default": {
"host": DEFAULT_BIND_HOST, "host": DEFAULT_BIND_HOST,
"port": 20002, "port": 20002,
"device": LLM_DEVICE, "device": LLM_DEVICE,
# todo: 多卡加载需要配置的参数
# 多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1" # "gpus": None, # 使用的GPU以str的格式指定如"0,1"
# "num_gpus": 1, # 使用GPU的数量 # "num_gpus": 1, # 使用GPU的数量
# 以下为非常用参数,可根据需要配置
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存 # "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
# 以下为非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化 # "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None, # "cpu_offloading": None,
# "gptq_ckpt": None, # "gptq_ckpt": None,
@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2, # "stream_interval": 2,
# "no_register": False, # "no_register": False,
}, },
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
},
"chatglm-api": { # 请为每个在线API设置不同的端口
"port": 20003,
},
} }
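A sketch of how a worker's address is expected to be derived from this table (mirroring what server.utils.fschat_model_worker_address(model_name) presumably does; per-model entries only override the fields that differ from "default"):

def example_worker_address(model_name: str) -> str:
    # merge the shared defaults with the model-specific overrides
    cfg = {**FSCHAT_MODEL_WORKERS.get("default", {}), **FSCHAT_MODEL_WORKERS.get(model_name, {})}
    return f"http://{cfg['host']}:{cfg['port']}"

# e.g. example_worker_address("chatglm-api") uses the default host with port 20003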
# fastchat multi model worker server # fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = { FSCHAT_MULTI_MODEL_WORKERS = {
# todo # TODO:
} }
# fastchat controller server # fastchat controller server

View File

@ -4,11 +4,12 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION from configs import VERSION
import argparse import argparse
import uvicorn import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat, from server.chat import (chat, knowledge_base_chat, openai_chat,
@ -17,7 +18,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc, from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store, update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore) search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from typing import List from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。" summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store) )(recreate_vector_store)
# LLM模型相关接口
@app.post("/llm_model/list_models",
          tags=["LLM Model Management"],
          summary="列出当前已加载的模型")
def list_models(
    controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
    '''
    从fastchat controller获取已加载模型列表
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        r = httpx.post(controller_address + "/list_models")
        return BaseResponse(data=r.json()["models"])
    except Exception as e:
        return BaseResponse(
            code=500,
            data=[],
            msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")

@app.post("/llm_model/stop",
          tags=["LLM Model Management"],
          summary="停止指定的LLM模型(Model Worker)",
          )
def stop_llm_model(
    model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
    controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
    '''
    向fastchat controller请求停止某个LLM模型
    注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        r = httpx.post(
            controller_address + "/release_worker",
            json={"model_name": model_name},
        )
        return r.json()
    except Exception as e:
        return BaseResponse(
            code=500,
            msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")

@app.post("/llm_model/change",
          tags=["LLM Model Management"],
          summary="切换指定的LLM模型(Model Worker)",
          )
def change_llm_model(
    model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
    new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
    controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
    '''
    向fastchat controller请求切换LLM模型
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        r = httpx.post(
            controller_address + "/release_worker",
            json={"model_name": model_name, "new_model_name": new_model_name},
            timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
        )
        return r.json()
    except Exception as e:
        return BaseResponse(
            code=500,
            msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
return app return app
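Example client calls against the new endpoints (a sketch; it assumes the API server from this repo is running at the address returned by configs.server_config.api_address, as in the tests added below):

import requests
from configs.server_config import api_address

base = api_address()
running = requests.post(base + "/llm_model/list_models").json()["data"]
print(running)
# switch the first running model to "chatglm-api"; allow a generous timeout while the new worker registers
r = requests.post(base + "/llm_model/change",
                  json={"model_name": running[0], "new_model_name": "chatglm-api"},
                  timeout=300)
print(r.json())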

View File

@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]] {"role": "assistant", "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
): ):
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
async def chat_iterator(query: str, async def chat_iterator(query: str,
history: List[History] = [], history: List[History] = [],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()
@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True, streaming=True,
verbose=True, verbose=True,
callbacks=[callback], callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=LLM_MODEL, model_name=model_name,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") openai_proxy=llm_model_dict[model_name].get("openai_proxy")
) )
input_msg = History(role="user", content="{{ input }}").to_msg_template(False) input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task await task
return StreamingResponse(chat_iterator(query, history), return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream") media_type="text/event-stream")

View File

@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]] "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None, request: Request = None,
): ):
@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService, kb: KBService,
top_k: int, top_k: int,
history: Optional[List[History]], history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI( model = ChatOpenAI(
streaming=True, streaming=True,
verbose=True, verbose=True,
callbacks=[callback], callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=LLM_MODEL, model_name=model_name,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") openai_proxy=llm_model_dict[model_name].get("openai_proxy")
) )
docs = search_docs(query, knowledge_base_name, top_k, score_threshold) docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history), return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream") media_type="text/event-stream")

View File

@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]] "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
): ):
if search_engine_name not in SEARCH_ENGINES.keys(): if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str, search_engine_name: str,
top_k: int, top_k: int,
history: Optional[List[History]], history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI( model = ChatOpenAI(
streaming=True, streaming=True,
verbose=True, verbose=True,
callbacks=[callback], callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=LLM_MODEL, model_name=model_name,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy") openai_proxy=llm_model_dict[model_name].get("openai_proxy")
) )
docs = lookup_search_engine(query, search_engine_name, top_k) docs = lookup_search_engine(query, search_engine_name, top_k)
@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False) ensure_ascii=False)
await task await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history), return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream") media_type="text/event-stream")

View File

@ -6,7 +6,8 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text from sqlalchemy import text
from configs.model_config import kbs_config from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile from server.knowledge_base.utils import load_embeddings, KnowledgeFile

View File

@ -0,0 +1 @@
from .zhipu import ChatGLMWorker

View File

@ -0,0 +1,71 @@
from configs.model_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# 恢复被fastchat覆盖的标准输出
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
error_code: int = 0
text: str
class ApiModelWorker(BaseModelWorker):
BASE_URL: str
SUPPORT_MODELS: List
def __init__(
self,
model_names: List[str],
controller_addr: str,
worker_addr: str,
context_len: int = 2048,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
self.init_heart_beat()
def count_token(self, params):
# TODO需要完善
print("count token")
print(params)
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params):
self.call_ct += 1
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
pass
return json.loads(x[:-1].decode())
def get_embeddings(self, params):
print("embedding")
print(params)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
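For orientation, a rough sketch of what a subclass for another provider would have to supply, modelled on the ChatGLMWorker that follows; the provider name "ExampleWorker", its model name and the echoed response are all made up:

import json
from typing import List
from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker

class ExampleWorker(ApiModelWorker):
    def __init__(self, *, model_names: List[str] = ["example-api"],
                 controller_addr: str, worker_addr: str, **kwargs):
        kwargs.update(model_names=model_names,
                      controller_addr=controller_addr,
                      worker_addr=worker_addr)
        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        # each online API needs its own conv template
        self.conv = conv.Conversation(
            name="example-api", system_message="", messages=[],
            roles=["Human", "Assistant"], sep="\n### ", stop_str="###",
        )

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)
        # call the remote API here (omitted) and yield NUL-terminated JSON chunks
        text = f"echo: {params['prompt']}"
        yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"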

View File

@ -0,0 +1,75 @@
import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
def __init__(
self,
*,
model_names: List[str] = ["chatglm-api"],
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
# 这里的是chatglm api的模板其它API的conv_template需要定制
self.conv = conv.Conversation(
name="chatglm-api",
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["Human", "Assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 支持stream参数维护request_id传过来的prompt也有问题
from server.utils import get_model_worker_config
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
response = zhipuai.model_api.sse_invoke(
model=self.version,
prompt=[{"role": "user", "content": params["prompt"]}],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
)
for e in response.events():
if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
# TODO: 更健壮的消息处理
# elif e.event == "finish":
# ...
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)

View File

@ -5,13 +5,17 @@ import torch
from fastapi import FastAPI from fastapi import FastAPI
from pathlib import Path from pathlib import Path
import asyncio import asyncio
from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
from typing import Literal, Optional from configs.server_config import FSCHAT_MODEL_WORKERS
import os
from server import model_workers
from typing import Literal, Optional, Any
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code") code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message") msg: str = pydantic.Field("success", description="API status message")
data: Any = pydantic.Field(None, description="API data")
class Config: class Config:
schema_extra = { schema_extra = {
@ -201,10 +205,28 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {})) config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {})) config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
config["device"] = llm_device(config.get("device"))
# 如果没有设置有效的local_model_path则认为是在线模型API
if not os.path.isdir(config.get("local_model_path", "")):
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
print(f"在线模型 {model_name} 的provider没有正确配置")
return config return config
def get_all_model_worker_configs() -> dict:
    result = {}
    model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
    for name in model_names:
        if name != "default":
            result[name] = get_model_worker_config(name)
    return result
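An illustration (not part of the diff) of what the merged worker config is expected to contain for the online chatglm-api entry; exact values depend on the local configs:

from server.utils import get_model_worker_config, get_all_model_worker_configs

cfg = get_model_worker_config("chatglm-api")
print(cfg.get("online_api"))     # True: no valid local_model_path, so it is treated as an online API
print(cfg.get("worker_class"))   # ChatGLMWorker, resolved from the "provider" field
print(cfg.get("port"))           # 20003, from FSCHAT_MODEL_WORKERS["chatglm-api"]

# one config per model found in llm_model_dict or FSCHAT_MODEL_WORKERS (minus "default")
print(list(get_all_model_worker_configs()))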
def fschat_controller_address() -> str: def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER from configs.server_config import FSCHAT_CONTROLLER

View File

@ -16,14 +16,14 @@ except:
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \ from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, ) FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address, from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout, fschat_openai_api_address, set_httpx_timeout,
llm_device, embedding_device, get_model_worker_config) get_model_worker_config, get_all_model_worker_configs,
from server.utils import MakeFastAPIOffline, FastAPI MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse import argparse
from typing import Tuple, List from typing import Tuple, List, Dict
from configs import VERSION from configs import VERSION
@ -41,6 +41,7 @@ def create_controller_app(
MakeFastAPIOffline(app) MakeFastAPIOffline(app)
app.title = "FastChat Controller" app.title = "FastChat Controller"
app._controller = controller
return app return app
@ -97,43 +98,62 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
) )
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
# 在线模型API
if worker_class := kwargs.get("worker_class"):
    worker = worker_class(model_names=args.model_names,
                          controller_addr=args.controller_address,
                          worker_addr=args.worker_address)
# 本地模型
else:
    # workaround to make program exit with Ctrl+c
    # it should be deleted after pr is merged by fastchat
    def _new_init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()

    ModelWorker.init_heart_beat = _new_init_heart_beat

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        worker_id=worker_id,
        model_path=args.model_path,
        model_names=args.model_names,
        limit_worker_concurrency=args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
    )
    sys.modules["fastchat.serve.model_worker"].args = args
    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

sys.modules["fastchat.serve.model_worker"].worker = worker

MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app return app
@ -190,6 +210,9 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"): def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
import uvicorn import uvicorn
import httpx
from fastapi import Body
import time
import sys import sys
app = create_controller_app( app = create_controller_app(
@ -198,11 +221,71 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str ="INFO"):
) )
_set_app_seq(app, q, run_seq) _set_app_seq(app, q, run_seq)
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
    model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
    # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
    new_model_name: str = Body(None, description="释放后加载该模型"),
    keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
    available_models = app._controller.list_models()
    if new_model_name in available_models:
        msg = f"要切换的LLM模型 {new_model_name} 已经存在"
        logger.info(msg)
        return {"code": 500, "msg": msg}

    if new_model_name:
        logger.info(f"开始切换LLM模型{model_name}{new_model_name}")
    else:
        logger.info(f"即将停止LLM模型 {model_name}")

    if model_name not in available_models:
        msg = f"the model {model_name} is not available"
        logger.error(msg)
        return {"code": 500, "msg": msg}

    worker_address = app._controller.get_worker_address(model_name)
    if not worker_address:
        msg = f"can not find model_worker address for {model_name}"
        logger.error(msg)
        return {"code": 500, "msg": msg}

    r = httpx.post(worker_address + "/release",
                   json={"new_model_name": new_model_name, "keep_origin": keep_origin})
    if r.status_code != 200:
        msg = f"failed to release model: {model_name}"
        logger.error(msg)
        return {"code": 500, "msg": msg}

    if new_model_name:
        timer = 300  # wait 5 minutes for new model_worker register
        while timer > 0:
            models = app._controller.list_models()
            if new_model_name in models:
                break
            time.sleep(1)
            timer -= 1
        if timer > 0:
            msg = f"success to change model from {model_name} to {new_model_name}"
            logger.info(msg)
            return {"code": 200, "msg": msg}
        else:
            msg = f"failed to change model from {model_name} to {new_model_name}"
            logger.error(msg)
            return {"code": 500, "msg": msg}
    else:
        msg = f"success to release model: {model_name}"
        logger.info(msg)
        return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"] host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"] port = FSCHAT_CONTROLLER["port"]
if log_level == "ERROR": if log_level == "ERROR":
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
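A sketch of calling the controller's /release_worker endpoint directly (the /llm_model/* routes in api.py forward to it); the model names are examples from the configs above:

import httpx
from server.utils import fschat_controller_address

r = httpx.post(
    fschat_controller_address() + "/release_worker",
    json={"model_name": "chatglm-6b",
          "new_model_name": "chatglm-api",
          "keep_origin": False},   # True keeps chatglm-6b running and starts chatglm-api in addition
    timeout=300,                   # the controller may poll up to 5 minutes for the new worker
)
print(r.json())                    # {"code": 200, "msg": "..."} on success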
@ -214,16 +297,17 @@ def run_model_worker(
log_level: str ="INFO", log_level: str ="INFO",
): ):
import uvicorn import uvicorn
from fastapi import Body
import sys import sys
kwargs = get_model_worker_config(model_name) kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host") host = kwargs.pop("host")
port = kwargs.pop("port") port = kwargs.pop("port")
model_path = llm_model_dict[model_name].get("local_model_path", "")
kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name] kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address() kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address() kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("local_model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs) app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq) _set_app_seq(app, q, run_seq)
@ -231,6 +315,22 @@ def run_model_worker(
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
# add interface to release and load model
@app.post("/release")
def release_model(
    new_model_name: str = Body(None, description="释放后加载该模型"),
    keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
    if keep_origin:
        if new_model_name:
            q.put(["start", new_model_name])
    else:
        if new_model_name:
            q.put(["replace", new_model_name])
        else:
            q.put(["stop"])
    return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
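The handler that consumes these queue messages in the parent process is not part of this hunk, so the following is only a rough sketch of the intended protocol; every name here is an assumption:

def handle_release_command(msg, start_worker, stop_worker):
    # msg is one of ["start", name], ["replace", name], ["stop"] as put on the queue above
    cmd = msg[0]
    if cmd == "start":        # keep_origin=True: keep the old worker and start another one
        start_worker(msg[1])
    elif cmd == "replace":    # stop the current worker, then start the new model
        stop_worker()
        start_worker(msg[1])
    elif cmd == "stop":
        stop_worker()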
@ -337,6 +437,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server", help="run api.py server",
dest="api", dest="api",
) )
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument( parser.add_argument(
"-w", "-w",
"--webui", "--webui",
@ -368,9 +475,14 @@ def dump_server_info(after_start=False, args=None):
print(f"项目版本:{VERSION}") print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}") print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n") print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {llm_device()}")
pprint(llm_model_dict[LLM_MODEL]) model = LLM_MODEL
if args and args.model_name:
model = args.model_name
print(f"当前LLM模型{model} @ {llm_device()}")
pprint(llm_model_dict[model])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}") print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start: if after_start:
print("\n") print("\n")
print(f"服务端运行信息:") print(f"服务端运行信息:")
@ -396,17 +508,20 @@ if __name__ == "__main__":
args.openai_api = True args.openai_api = True
args.model_worker = True args.model_worker = True
args.api = True args.api = True
args.api_worker = True
args.webui = True args.webui = True
elif args.all_api: elif args.all_api:
args.openai_api = True args.openai_api = True
args.model_worker = True args.model_worker = True
args.api = True args.api = True
args.api_worker = True
args.webui = False args.webui = False
elif args.llm_api: elif args.llm_api:
args.openai_api = True args.openai_api = True
args.model_worker = True args.model_worker = True
args.api_worker = True
args.api = False args.api = False
args.webui = False args.webui = False
@ -416,7 +531,11 @@ if __name__ == "__main__":
logger.info(f"正在启动服务:") logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {} processes = {"online-api": []}
def process_count():
    return len(processes) + len(processes["online-api"]) - 1
if args.quiet: if args.quiet:
log_level = "ERROR" log_level = "ERROR"
else: else:
@ -426,7 +545,7 @@ if __name__ == "__main__":
process = Process( process = Process(
target=run_controller, target=run_controller,
name=f"controller({os.getpid()})", name=f"controller({os.getpid()})",
args=(queue, len(processes) + 1, log_level), args=(queue, process_count() + 1, log_level),
daemon=True, daemon=True,
) )
process.start() process.start()
@ -435,29 +554,42 @@ if __name__ == "__main__":
process = Process( process = Process(
target=run_openai_api, target=run_openai_api,
name=f"openai_api({os.getpid()})", name=f"openai_api({os.getpid()})",
args=(queue, len(processes) + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start() process.start()
processes["openai_api"] = process processes["openai_api"] = process
if args.model_worker: if args.model_worker:
model_path = llm_model_dict[args.model_name].get("local_model_path", "") config = get_model_worker_config(args.model_name)
if os.path.isdir(model_path): if not config.get("online_api"):
process = Process( process = Process(
target=run_model_worker, target=run_model_worker,
name=f"model_worker({os.getpid()})", name=f"model_worker - {args.model_name} ({os.getpid()})",
args=(args.model_name, args.controller_address, queue, len(processes) + 1, log_level), args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True, daemon=True,
) )
process.start() process.start()
processes["model_worker"] = process processes["model_worker"] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
if config.get("online_api") and config.get("worker_class"):
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name} ({os.getpid()})",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
process.start()
processes["online-api"].append(process)
if args.api: if args.api:
process = Process( process = Process(
target=run_api_server, target=run_api_server,
name=f"API Server{os.getpid()})", name=f"API Server{os.getpid()})",
args=(queue, len(processes) + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start() process.start()
@ -467,37 +599,38 @@ if __name__ == "__main__":
process = Process( process = Process(
target=run_webui, target=run_webui,
name=f"WEBUI Server{os.getpid()})", name=f"WEBUI Server{os.getpid()})",
args=(queue, len(processes) + 1), args=(queue, process_count() + 1),
daemon=True, daemon=True,
) )
process.start() process.start()
processes["webui"] = process processes["webui"] = process
if len(processes) == 0: if process_count() == 0:
parser.print_help() parser.print_help()
else: else:
try: try:
# log infos
while True: while True:
no = queue.get() no = queue.get()
if no == len(processes): if no == process_count():
time.sleep(0.5) time.sleep(0.5)
dump_server_info(after_start=True, args=args) dump_server_info(after_start=True, args=args)
break break
else: else:
queue.put(no) queue.put(no)
if model_worker_process := processes.get("model_worker"): if model_worker_process := processes.pop("model_worker", None):
model_worker_process.join() model_worker_process.join()
for process in processes.pop("online-api", []):
process.join()
for name, process in processes.items(): for name, process in processes.items():
if name != "model_worker": process.join()
process.join()
except: except:
if model_worker_process := processes.get("model_worker"): if model_worker_process := processes.pop("model_worker", None):
model_worker_process.terminate() model_worker_process.terminate()
for process in processes.pop("online-api", []):
process.terminate()
for name, process in processes.items(): for name, process in processes.items():
if name != "model_worker": process.terminate()
process.terminate()
# 服务启动后接口调用示例: # 服务启动后接口调用示例:

View File

@ -1,4 +1,3 @@
from doctest import testfile
import requests import requests
import json import json
import sys import sys

tests/api/test_llm_api.py (new file, 74 lines added)
View File

@ -0,0 +1,74 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from pprint import pprint
import random
def get_configured_models():
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
llm_dict = list(llm_model_dict)
return model_workers, llm_dict
api_base_url = api_address()
def get_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
if r.status_code == 200:
return r.json()["data"]
return []
def test_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
assert r.status_code == 200
print("\n获取当前正在运行的模型列表:")
pprint(r.json())
assert isinstance(r.json()["data"], list)
assert len(r.json()["data"]) > 0
# 不建议使用stop_model功能。按现在的实现停止了就只能手动再启动
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change"):
url = api_base_url + api
running_models = get_running_models()
assert len(running_models) > 0
model_workers, llm_dict = get_configured_models()
available_new_models = set(model_workers) - set(running_models)
if len(available_new_models) == 0:
    available_new_models = set(llm_dict) - set(running_models)
available_new_models = list(available_new_models)
assert len(available_new_models) > 0
print(available_new_models)
model_name = random.choice(running_models)
new_model_name = random.choice(available_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
assert r.status_code == 200
running_models = get_running_models()
assert new_model_name in running_models

View File

@ -1,10 +1,14 @@
import streamlit as st import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_chatbox import * from streamlit_chatbox import *
from datetime import datetime from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES from server.chat.search_engine_chat import SEARCH_ENGINES
from typing import List, Dict
import os import os
from configs.model_config import llm_model_dict, LLM_MODEL
from server.utils import get_model_worker_config
from typing import List, Dict
chat_box = ChatBox( chat_box = ChatBox(
assistant_avatar=os.path.join( assistant_avatar=os.path.join(
@ -59,6 +63,38 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change, on_change=on_mode_change,
key="dialogue_mode", key="dialogue_mode",
) )
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model
def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x
running_models = api.list_running_models()
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN) history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change(): def on_kb_change():
@ -99,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话": if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...") chat_box.ai_say("正在思考...")
text = "" text = ""
r = api.chat_chat(prompt, history) r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r: for t in r:
if error_msg := check_error_msg(t): # check whether error occured if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg) st.error(error_msg)
@ -114,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"), Markdown("...", in_expander=True, title="知识库匹配结果"),
]) ])
text = "" text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history): for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg) st.error(error_msg)
text += d["answer"] text += d["answer"]
@ -127,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"), Markdown("...", in_expander=True, title="网络搜索结果"),
]) ])
text = "" text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k): for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg) st.error(error_msg)
else: else:
text += d["answer"] text += d["answer"]

View File

@ -6,12 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE, DEFAULT_VS_TYPE,
KB_ROOT_PATH, KB_ROOT_PATH,
LLM_MODEL, LLM_MODEL,
llm_model_dict,
HISTORY_LEN, HISTORY_LEN,
SCORE_THRESHOLD, SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K, SEARCH_ENGINE_TOP_K,
logger, logger,
) )
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx import httpx
import asyncio import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn from server.chat.openai_chat import OpenAiChatMsgIn
@ -42,7 +44,7 @@ class ApiRequest:
def __init__( def __init__(
self, self,
base_url: str = "http://127.0.0.1:7861", base_url: str = "http://127.0.0.1:7861",
timeout: float = 60.0, timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly no_remote_api: bool = False, # call api view function directly
): ):
self.base_url = base_url self.base_url = base_url
@ -289,6 +291,7 @@ class ApiRequest:
query: str, query: str,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -301,6 +304,7 @@ class ApiRequest:
"query": query, "query": query,
"history": history, "history": history,
"stream": stream, "stream": stream,
"model_name": model,
} }
print(f"received input message:") print(f"received input message:")
@ -322,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -337,6 +342,7 @@ class ApiRequest:
"score_threshold": score_threshold, "score_threshold": score_threshold,
"history": history, "history": history,
"stream": stream, "stream": stream,
"model_name": model,
"local_doc_url": no_remote_api, "local_doc_url": no_remote_api,
} }
@ -361,6 +367,7 @@ class ApiRequest:
search_engine_name: str, search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K, top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None, no_remote_api: bool = None,
): ):
''' '''
@ -374,6 +381,7 @@ class ApiRequest:
"search_engine_name": search_engine_name, "search_engine_name": search_engine_name,
"top_k": top_k, "top_k": top_k,
"stream": stream, "stream": stream,
"model_name": model,
} }
print(f"received input message:") print(f"received input message:")
@ -645,6 +653,84 @@ class ApiRequest:
) )
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None):
'''
获取Fastchat中正运行的模型列表
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
def list_config_models(self):
'''
获取configs中配置的模型列表
'''
return list(llm_model_dict.keys())
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
'''
if not model_name or not new_model_name:
return
if new_model_name == model_name:
return {
"code": 200,
"msg": "什么都不用做"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models:
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
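A usage sketch for these webui-side helpers (it assumes this ApiRequest class is importable from webui_pages.utils and that the API server is reachable at the default base_url; model names are examples):

from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
print(api.list_running_models())    # e.g. ["chatglm-6b"]
print(api.list_config_models())     # every key of llm_model_dict
# blocks until the new worker has registered with the controller
print(api.change_llm_model("chatglm-6b", "chatglm-api"))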
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
''' '''