Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-02 12:46:56 +08:00)

Commit 48fb6b83fd (parent: 8063aab7a1): integrate the OpenAI plugins feature (集成openai plugins插件)
@@ -8,7 +8,7 @@ torchaudio>=2.1.2
# Langchain 0.1.x requirements
langchain>=0.1.0
langchain_openai>=0.0.2
langchain-community>=1.0.0
langchain-community>=0.0.11
langchainhub>=0.1.14

pydantic==1.10.13
@@ -17,9 +17,7 @@ from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
                            change_llm_model, stop_llm_model,
                            get_model_config)

from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                          get_server_configs, get_prompt_template)
from typing import List, Literal

@@ -73,32 +71,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
    # 摘要相关接口
    mount_filename_summary_routes(app)

    # LLM模型相关接口
    app.post("/llm_model/list_running_models",
             tags=["LLM Model Management"],
             summary="列出当前已加载的模型",
             )(list_running_models)

    app.post("/llm_model/list_config_models",
             tags=["LLM Model Management"],
             summary="列出configs已配置的模型",
             )(list_config_models)

    app.post("/llm_model/get_model_config",
             tags=["LLM Model Management"],
             summary="获取模型配置(合并后)",
             )(get_model_config)

    app.post("/llm_model/stop",
             tags=["LLM Model Management"],
             summary="停止指定的LLM模型(Model Worker)",
             )(stop_llm_model)

    app.post("/llm_model/change",
             tags=["LLM Model Management"],
             summary="切换指定的LLM模型(Model Worker)",
             )(change_llm_model)

    # 服务器相关接口
    app.post("/server/configs",
             tags=["Server State"],
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus


def create_models_from_config(configs, callbacks, stream):
def create_models_from_config(configs, openai_config, callbacks, stream):
    if configs is None:
        configs = {}
    models = {}

@@ -30,6 +30,9 @@ def create_models_from_config(configs, callbacks, stream):
    for model_name, params in model_configs.items():
        callbacks = callbacks if params.get('callbacks', False) else None
        model_instance = get_ChatOpenAI(
            endpoint_host=openai_config.get('endpoint_host', None),
            endpoint_host_key=openai_config.get('endpoint_host_key', None),
            endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
            model_name=model_name,
            temperature=params.get('temperature', 0.5),
            max_tokens=params.get('max_tokens', 1000),

@@ -113,6 +116,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               ),
               stream: bool = Body(True, description="流式输出"),
               model_config: Dict = Body({}, description="LLM 模型配置"),
               openai_config: Dict = Body({}, description="openaiEndpoint配置"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
    async def chat_iterator() -> AsyncIterable[str]:

@@ -124,7 +128,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼

        callback = AgentExecutorAsyncIteratorCallbackHandler()
        callbacks = [callback]
        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
                                                    openai_config=openai_config, stream=stream)
        tools = [tool for tool in all_tools if tool.name in tool_config]
        tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
        full_chain = create_models_chains(prompts=prompts,
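For orientation, here is a minimal client-side sketch of a request to the updated chat endpoint. It is not part of the commit: the /chat/chat path, the server port, the model name and the endpoint URL are assumptions based on the surrounding code.

# Illustrative sketch only -- path, port, model name and endpoint values are assumptions.
import requests

payload = {
    "query": "你好",
    "stream": False,
    "model_config": {"llm_model": {"gpt-3.5-turbo": {"temperature": 0.5, "max_tokens": 1000}}},
    # New in this commit: per-request OpenAI endpoint settings read by create_models_from_config()
    "openai_config": {
        "endpoint_host": "http://127.0.0.1:8000/v1",
        "endpoint_host_key": "EMPTY",
        "endpoint_host_proxy": None,
    },
    "tool_config": {},
}
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)
print(resp.text)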
@@ -14,6 +14,9 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     stream: bool = Body(False, description="流式输出"),
                     echo: bool = Body(False, description="除了输出之外,还回显输入"),
                     endpoint_host: str = Body(False, description="接入点地址"),
                     endpoint_host_key: str = Body(False, description="接入点key"),
                     endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
                     model_name: str = Body(None, description="LLM 模型名称。"),
                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),

@@ -24,6 +27,9 @@ async def completion(query: str = Body(..., description="用户输入", examples

    #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
    async def completion_iterator(query: str,
                                  endpoint_host: str,
                                  endpoint_host_key: str,
                                  endpoint_host_proxy: str,
                                  model_name: str = None,
                                  prompt_name: str = prompt_name,
                                  echo: bool = echo,

@@ -34,6 +40,9 @@ async def completion(query: str = Body(..., description="用户输入", examples
            max_tokens = None

        model = get_OpenAI(
            endpoint_host=endpoint_host,
            endpoint_host_key=endpoint_host_key,
            endpoint_host_proxy=endpoint_host_proxy,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,

@@ -63,7 +72,10 @@ async def completion(query: str = Body(..., description="用户输入", examples

        await task

    return EventSourceResponse(completion_iterator(query=query,
    return StreamingResponse(completion_iterator(query=query,
                                                  endpoint_host=endpoint_host,
                                                  endpoint_host_key=endpoint_host_key,
                                                  endpoint_host_proxy=endpoint_host_proxy,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
                             )
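The completion endpoint now takes the endpoint settings as top-level body fields and returns a StreamingResponse instead of an EventSourceResponse. A hedged sketch of a streaming call follows; the URL, model name and endpoint values are assumptions.

# Illustrative sketch only -- URL, model name and endpoint values are assumptions.
import requests

payload = {
    "query": "写一句话介绍 LangChain",
    "stream": True,
    "echo": False,
    "endpoint_host": "http://127.0.0.1:8000/v1",
    "endpoint_host_key": "EMPTY",
    "endpoint_host_proxy": None,
    "model_name": "gpt-3.5-turbo",
    "temperature": 0.01,
    "max_tokens": 1024,
}
with requests.post("http://127.0.0.1:7861/chat/completion", json=payload, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")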
@@ -1,124 +0,0 @@
from fastapi import Body
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                          get_httpx_client, get_model_worker_config)
from typing import List


def list_running_models(
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
        placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
    '''
    从fastchat controller获取已加载模型列表及其配置项
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        with get_httpx_client() as client:
            r = client.post(controller_address + "/list_models")
            models = r.json()["models"]
            data = {m: get_model_config(m).data for m in models}

            ## 只有LLM模型才返回
            result = {}
            for model, config in data.items():
                if model in LLM_MODEL_CONFIG['llm_model']:
                    result[model] = config

            return BaseResponse(data=result)
    except Exception as e:
        logger.error(f'{e.__class__.__name__}: {e}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(
            code=500,
            data={},
            msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")


def list_config_models(
        types: List[str] = Body(["local", "online"], description="模型配置项类别,如local, online, worker"),
        placeholder: str = Body(None, description="占位用,无实际效果")
) -> BaseResponse:
    '''
    从本地获取configs中配置的模型列表
    '''
    data = {}
    result = {}

    for type, models in list_config_llm_models().items():
        if type in types:
            data[type] = {m: get_model_config(m).data for m in models}

    for model, config in data.items():
        if model in LLM_MODEL_CONFIG['llm_model']:
            result[type][model] = config

    return BaseResponse(data=result)


def get_model_config(
        model_name: str = Body(description="配置中LLM模型的名称"),
        placeholder: str = Body(None, description="占位用,无实际效果")
) -> BaseResponse:
    '''
    获取LLM模型配置项(合并后的)
    '''
    config = {}
    # 删除ONLINE_MODEL配置中的敏感信息
    for k, v in get_model_worker_config(model_name=model_name).items():
        if not (k == "worker_class"
                or "key" in k.lower()
                or "secret" in k.lower()
                or k.lower().endswith("id")):
            config[k] = v

    return BaseResponse(data=config)
def stop_llm_model(
        model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
    '''
    向fastchat controller请求停止某个LLM模型。
    注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        with get_httpx_client() as client:
            r = client.post(
                controller_address + "/release_worker",
                json={"model_name": model_name},
            )
            return r.json()
    except Exception as e:
        logger.error(f'{e.__class__.__name__}: {e}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(
            code=500,
            msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")


def change_llm_model(
        model_name: str = Body(..., description="当前运行模型"),
        new_model_name: str = Body(..., description="要切换的新模型"),
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
    '''
    向fastchat controller请求切换LLM模型。
    '''
    try:
        controller_address = controller_address or fschat_controller_address()
        with get_httpx_client() as client:
            r = client.post(
                controller_address + "/release_worker",
                json={"model_name": model_name, "new_model_name": new_model_name},
                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
            )
            return r.json()
    except Exception as e:
        logger.error(f'{e.__class__.__name__}: {e}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(
            code=500,
            msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
server/utils.py (104 lines changed)

@@ -45,6 +45,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):


def get_ChatOpenAI(
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        model_name: str,
        temperature: float,
        max_tokens: int = None,

@@ -61,18 +64,21 @@ def get_ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        openai_api_key=endpoint_host_key if endpoint_host_key else "None",
        openai_api_base=endpoint_host if endpoint_host else "None",
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
        **kwargs
    )
    return model


def get_OpenAI(
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        model_name: str,
        temperature: float,
        max_tokens: int = None,

@@ -82,19 +88,18 @@ def get_OpenAI(
        verbose: bool = True,
        **kwargs: Any,
) -> OpenAI:
    config = get_model_worker_config(model_name)
    if model_name == "openai-api":
        model_name = config.get("model_name")

    # TODO: 从API获取模型信息
    model = OpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        openai_api_key=endpoint_host_key if endpoint_host_key else "None",
        openai_api_base=endpoint_host if endpoint_host else "None",
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
        echo=echo,
        **kwargs
    )
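With this change get_ChatOpenAI() and get_OpenAI() no longer read api_key/api_base_url from get_model_worker_config(); the caller supplies the OpenAI-compatible endpoint explicitly. A minimal sketch of a direct call follows; the endpoint URL, key and model name are assumptions.

# Illustrative sketch only -- endpoint URL, key and model name are assumptions.
from server.utils import get_ChatOpenAI

llm = get_ChatOpenAI(
    endpoint_host="http://127.0.0.1:8000/v1",   # OpenAI-compatible base URL exposed by a plugin
    endpoint_host_key="EMPTY",                  # passed through as openai_api_key (falls back to "None")
    endpoint_host_proxy=None,                   # passed through as openai_proxy
    model_name="gpt-3.5-turbo",
    temperature=0.7,
    max_tokens=1000,
)
print(llm.invoke("ping"))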
@@ -365,69 +370,14 @@ def get_model_worker_config(model_name: str = None) -> dict:
    '''
    from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
    from configs.server_config import FSCHAT_MODEL_WORKERS
    from server import model_workers

    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
        if provider := config.get("provider"):
            try:
                config["worker_class"] = getattr(model_workers, provider)
            except Exception as e:
                msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
    # 本地模型
    if model_name in MODEL_PATH["llm_model"]:
        path = get_model_path(model_name)
        config["model_path"] = path
        if path and os.path.isdir(path):
            config["model_path_exists"] = True
        config["device"] = llm_device(config.get("device"))

    return config


def get_all_model_worker_configs() -> dict:
    result = {}
    model_names = set(FSCHAT_MODEL_WORKERS.keys())
    for name in model_names:
        if name != "default":
            result[name] = get_model_worker_config(name)
    return result


def fschat_controller_address() -> str:
    from configs.server_config import FSCHAT_CONTROLLER

    host = FSCHAT_CONTROLLER["host"]
    if host == "0.0.0.0":
        host = "127.0.0.1"
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
        host = model["host"]
        if host == "0.0.0.0":
            host = "127.0.0.1"
        port = model["port"]
        return f"http://{host}:{port}"
    return ""


def fschat_openai_api_address() -> str:
    from configs.server_config import FSCHAT_OPENAI_API

    host = FSCHAT_OPENAI_API["host"]
    if host == "0.0.0.0":
        host = "127.0.0.1"
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}/v1"


def api_address() -> str:
    from configs.server_config import API_SERVER

@@ -461,6 +411,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
def set_httpx_config(
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
        proxy: Union[str, Dict] = None,
        unused_proxies: List[str] = [],
):
    '''
    设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。

@@ -498,11 +449,7 @@ def set_httpx_config(
        "http://localhost",
    ]
    # do not use proxy for user deployed fastchat servers
    for x in [
        fschat_controller_address(),
        fschat_model_worker_address(),
        fschat_openai_api_address(),
    ]:
    for x in unused_proxies:
        host = ":".join(x.split(":")[:2])
        if host not in no_proxy:
            no_proxy.append(host)

@@ -568,6 +515,7 @@ def get_httpx_client(
        use_async: bool = False,
        proxies: Union[str, Dict] = None,
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
        unused_proxies: List[str] = [],
        **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    '''

@@ -579,11 +527,7 @@ def get_httpx_client(
        "all://localhost": None,
    }
    # do not use proxy for user deployed fastchat servers
    for x in [
        fschat_controller_address(),
        fschat_model_worker_address(),
        fschat_openai_api_address(),
    ]:
    for x in unused_proxies:
        host = ":".join(x.split(":")[:2])
        default_proxies.update({host: None})

@@ -629,8 +573,6 @@ def get_server_configs() -> Dict:
    获取configs中的原始配置项,供前端使用
    '''
    _custom = {
        "controller_address": fschat_controller_address(),
        "openai_api_address": fschat_openai_api_address(),
        "api_address": api_address(),
    }

@@ -638,14 +580,8 @@ def get_server_configs() -> Dict:


def list_online_embed_models() -> List[str]:
    from server import model_workers

    ret = []
    for k, v in list_config_llm_models()["online"].items():
        if provider := v.get("provider"):
            worker_class = getattr(model_workers, provider, None)
            if worker_class is not None and worker_class.can_embedding():
                ret.append(k)
    # TODO: 从在线API获取支持的模型列表
    return ret

startup.py (622 lines changed)

@@ -25,17 +25,11 @@ from configs import (
    LLM_MODEL_CONFIG,
    EMBEDDING_MODEL,
    TEXT_SPLITTER_NAME,
    FSCHAT_CONTROLLER,
    FSCHAT_OPENAI_API,
    FSCHAT_MODEL_WORKERS,
    API_SERVER,
    WEBUI_SERVER,
    HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
                          fschat_openai_api_address, get_httpx_client,
                          get_model_worker_config,
                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.utils import (FastAPI, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict

@@ -49,234 +43,6 @@ for model_category in LLM_MODEL_CONFIG.values():

all_model_names_list = list(all_model_names)

def create_controller_app(
        dispatch_method: str,
        log_level: str = "INFO",
) -> FastAPI:
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.controller import app, Controller, logger
    logger.setLevel(log_level)

    controller = Controller(dispatch_method)
    sys.modules["fastchat.serve.controller"].controller = controller

    MakeFastAPIOffline(app)
    app.title = "FastChat Controller"
    app._controller = controller
    return app


def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
    """
    kwargs包含的字段如下:
    host:
    port:
    model_names:[`model_name`]
    controller_address:
    worker_address:

    对于Langchain支持的模型:
        langchain_model:True
        不会使用fschat
    对于online_api:
        online_api:True
        worker_class: `provider`
    对于离线模型:
        model_path: `model_name_or_path`,huggingface的repo-id或本地路径
        device:`LLM_DEVICE`
    """
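The docstring above lists the kwargs that create_model_worker_app() expects. The dicts below are hypothetical examples for an online API worker and a local offline model; every name, port and path is made up for illustration only.

# Hypothetical kwargs for create_model_worker_app(); every value here is illustrative only.
online_api_kwargs = dict(
    host="127.0.0.1", port=21002,
    model_names=["some-online-api"],
    controller_address="http://127.0.0.1:20001",
    worker_address="http://127.0.0.1:21002",
    online_api=True,
    # worker_class is normally filled in from server.model_workers via the "provider" config key
)

local_model_kwargs = dict(
    host="127.0.0.1", port=21003,
    model_names=["some-local-llm"],
    controller_address="http://127.0.0.1:20001",
    worker_address="http://127.0.0.1:21003",
    model_path="/path/to/model",   # huggingface repo id or a local path
    device="cuda",
)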
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args([])

    for k, v in kwargs.items():
        setattr(args, k, v)
    if worker_class := kwargs.get("langchain_model"):  # Langchian支持的模型不用做操作
        from fastchat.serve.base_model_worker import app
        worker = ""
    # 在线模型API
    elif worker_class := kwargs.get("worker_class"):
        from fastchat.serve.base_model_worker import app

        worker = worker_class(model_names=args.model_names,
                              controller_addr=args.controller_address,
                              worker_addr=args.worker_address)
        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)

    # 本地模型
    else:
        from configs.model_config import VLLM_MODEL_DICT
        if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
            import fastchat.serve.vllm_worker
            from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
            from vllm import AsyncLLMEngine
            from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs

            args.tokenizer = args.model_path  # 如果tokenizer与model_path不一致在此处添加
            args.tokenizer_mode = 'auto'
            args.trust_remote_code = True
            args.download_dir = None
            args.load_format = 'auto'
            args.dtype = 'auto'
            args.seed = 0
            args.worker_use_ray = False
            args.pipeline_parallel_size = 1
            args.tensor_parallel_size = 1
            args.block_size = 16
            args.swap_space = 4  # GiB
            args.gpu_memory_utilization = 0.90
            args.max_num_batched_tokens = None  # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够
            args.max_num_seqs = 256
            args.disable_log_stats = False
            args.conv_template = None
            args.limit_worker_concurrency = 5
            args.no_register = False
            args.num_gpus = 4  # vllm worker的切分是tensor并行,这里填写显卡的数量
            args.engine_use_ray = False
            args.disable_log_requests = False

            # 0.2.2 vllm后要加的参数, 但是这里不需要
            args.max_model_len = None
            args.revision = None
            args.quantization = None
            args.max_log_len = None
            args.tokenizer_revision = None

            # 0.2.2 vllm需要新加的参数
            args.max_paddings = 256

            if args.model_path:
                args.model = args.model_path
            if args.num_gpus > 1:
                args.tensor_parallel_size = args.num_gpus

            for k, v in kwargs.items():
                setattr(args, k, v)

            engine_args = AsyncEngineArgs.from_cli_args(args)
            engine = AsyncLLMEngine.from_engine_args(engine_args)

            worker = VLLMWorker(
                controller_addr=args.controller_address,
                worker_addr=args.worker_address,
                worker_id=worker_id,
                model_path=args.model_path,
                model_names=args.model_names,
                limit_worker_concurrency=args.limit_worker_concurrency,
                no_register=args.no_register,
                llm_engine=engine,
                conv_template=args.conv_template,
            )
            sys.modules["fastchat.serve.vllm_worker"].engine = engine
            sys.modules["fastchat.serve.vllm_worker"].worker = worker
            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)

        else:
            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id

            args.gpus = "0"  # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
            args.max_gpu_memory = "22GiB"
            args.num_gpus = 1  # model worker的切分是model并行,这里填写显卡的数量

            args.load_8bit = False
            args.cpu_offloading = None
            args.gptq_ckpt = None
            args.gptq_wbits = 16
            args.gptq_groupsize = -1
            args.gptq_act_order = False
            args.awq_ckpt = None
            args.awq_wbits = 16
            args.awq_groupsize = -1
            args.model_names = [""]
            args.conv_template = None
            args.limit_worker_concurrency = 5
            args.stream_interval = 2
            args.no_register = False
            args.embed_in_truncate = False
            for k, v in kwargs.items():
                setattr(args, k, v)
            if args.gpus:
                if args.num_gpus is None:
                    args.num_gpus = len(args.gpus.split(','))
                if len(args.gpus.split(",")) < args.num_gpus:
                    raise ValueError(
                        f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
                    )
                os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
            gptq_config = GptqConfig(
                ckpt=args.gptq_ckpt or args.model_path,
                wbits=args.gptq_wbits,
                groupsize=args.gptq_groupsize,
                act_order=args.gptq_act_order,
            )
            awq_config = AWQConfig(
                ckpt=args.awq_ckpt or args.model_path,
                wbits=args.awq_wbits,
                groupsize=args.awq_groupsize,
            )
            worker = ModelWorker(
                controller_addr=args.controller_address,
                worker_addr=args.worker_address,
                worker_id=worker_id,
                model_path=args.model_path,
                model_names=args.model_names,
                limit_worker_concurrency=args.limit_worker_concurrency,
                no_register=args.no_register,
                device=args.device,
                num_gpus=args.num_gpus,
                max_gpu_memory=args.max_gpu_memory,
                load_8bit=args.load_8bit,
                cpu_offloading=args.cpu_offloading,
                gptq_config=gptq_config,
                awq_config=awq_config,
                stream_interval=args.stream_interval,
                conv_template=args.conv_template,
                embed_in_truncate=args.embed_in_truncate,
            )
            sys.modules["fastchat.serve.model_worker"].args = args
            sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
            # sys.modules["fastchat.serve.model_worker"].worker = worker
            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)

    MakeFastAPIOffline(app)
    app.title = f"FastChat LLM Server ({args.model_names[0]})"
    app._worker = worker
    return app


def create_openai_api_app(
        controller_address: str,
        api_keys: List = [],
        log_level: str = "INFO",
) -> FastAPI:
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
    from fastchat.utils import build_logger
    logger = build_logger("openai_api", "openai_api.log")
    logger.setLevel(log_level)

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    sys.modules["fastchat.serve.openai_api_server"].logger = logger
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys

    MakeFastAPIOffline(app)
    app.title = "FastChat OpeanAI API Server"
    return app


def _set_app_event(app: FastAPI, started_event: mp.Event = None):
    @app.on_event("startup")
@@ -285,154 +51,6 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
        started_event.set()


def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
    import uvicorn
    import httpx
    from fastapi import Body
    import time
    import sys
    from server.utils import set_httpx_config
    set_httpx_config()

    app = create_controller_app(
        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
        log_level=log_level,
    )
    _set_app_event(app, started_event)

    # add interface to release and load model worker
    @app.post("/release_worker")
    def release_worker(
            model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
            # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
            new_model_name: str = Body(None, description="释放后加载该模型"),
            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
    ) -> Dict:
        available_models = app._controller.list_models()
        if new_model_name in available_models:
            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
            logger.info(msg)
            return {"code": 500, "msg": msg}

        if new_model_name:
            logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
        else:
            logger.info(f"即将停止LLM模型: {model_name}")

        if model_name not in available_models:
            msg = f"the model {model_name} is not available"
            logger.error(msg)
            return {"code": 500, "msg": msg}

        worker_address = app._controller.get_worker_address(model_name)
        if not worker_address:
            msg = f"can not find model_worker address for {model_name}"
            logger.error(msg)
            return {"code": 500, "msg": msg}

        with get_httpx_client() as client:
            r = client.post(worker_address + "/release",
                            json={"new_model_name": new_model_name, "keep_origin": keep_origin})
            if r.status_code != 200:
                msg = f"failed to release model: {model_name}"
                logger.error(msg)
                return {"code": 500, "msg": msg}

        if new_model_name:
            timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
            while timer > 0:
                models = app._controller.list_models()
                if new_model_name in models:
                    break
                time.sleep(1)
                timer -= 1
            if timer > 0:
                msg = f"sucess change model from {model_name} to {new_model_name}"
                logger.info(msg)
                return {"code": 200, "msg": msg}
            else:
                msg = f"failed change model from {model_name} to {new_model_name}"
                logger.error(msg)
                return {"code": 500, "msg": msg}
        else:
            msg = f"sucess to release model: {model_name}"
            logger.info(msg)
            return {"code": 200, "msg": msg}

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]

    if log_level == "ERROR":
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())

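run_controller() attaches a /release_worker route to the FastChat controller, which is what the (removed) server/llm_api.py helpers used to call. A minimal sketch of invoking it directly follows; the controller address and model names are assumptions.

# Illustrative sketch only -- controller address and model names are assumptions.
import httpx

controller = "http://127.0.0.1:20001"

# Stop the worker that serves a model:
r = httpx.post(controller + "/release_worker", json={"model_name": "some-llm"})
print(r.json())

# Replace it with another model and wait for the new worker to register:
r = httpx.post(controller + "/release_worker",
               json={"model_name": "some-llm", "new_model_name": "another-llm"},
               timeout=300)
print(r.json())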
def run_model_worker(
        model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
        controller_address: str = "",
        log_level: str = "INFO",
        q: mp.Queue = None,
        started_event: mp.Event = None,
):
    import uvicorn
    from fastapi import Body
    import sys
    from server.utils import set_httpx_config
    set_httpx_config()

    kwargs = get_model_worker_config(model_name)
    host = kwargs.pop("host")
    port = kwargs.pop("port")
    kwargs["model_names"] = [model_name]
    kwargs["controller_address"] = controller_address or fschat_controller_address()
    kwargs["worker_address"] = fschat_model_worker_address(model_name)
    model_path = kwargs.get("model_path", "")
    kwargs["model_path"] = model_path
    app = create_model_worker_app(log_level=log_level, **kwargs)
    _set_app_event(app, started_event)
    if log_level == "ERROR":
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

    # add interface to release and load model
    @app.post("/release")
    def release_model(
            new_model_name: str = Body(None, description="释放后加载该模型"),
            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
    ) -> Dict:
        if keep_origin:
            if new_model_name:
                q.put([model_name, "start", new_model_name])
        else:
            if new_model_name:
                q.put([model_name, "replace", new_model_name])
            else:
                q.put([model_name, "stop", None])
        return {"code": 200, "msg": "done"}

    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())


def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
    import uvicorn
    import sys
    from server.utils import set_httpx_config
    set_httpx_config()

    controller_addr = fschat_controller_address()
    app = create_openai_api_app(controller_addr, log_level=log_level)
    _set_app_event(app, started_event)

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    if log_level == "ERROR":
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
    uvicorn.run(app, host=host, port=port)


def run_api_server(started_event: mp.Event = None, run_mode: str = None):
    from server.api import create_app
    import uvicorn

@@ -473,6 +91,18 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
    p.wait()


def run_loom(started_event: mp.Event = None):
    from configs import LOOM_CONFIG

    cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
           "-f", LOOM_CONFIG
           ]

    p = subprocess.Popen(cmd)
    started_event.set()
    p.wait()

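run_loom() is the new entry point added by this commit: it launches the loom_core OpenAI-plugins local deployment in a subprocess and waits on it. Running the same module by hand is roughly equivalent to the sketch below; the config file name is an assumption for the value of configs.LOOM_CONFIG.

# Equivalent standalone launch, for illustration only; the config path is an assumption.
import subprocess

LOOM_CONFIG = "loom.yaml"  # assumed value of configs.LOOM_CONFIG
subprocess.run(
    ["python", "-m", "loom_core.openai_plugins.deploy.local", "-f", LOOM_CONFIG],
    check=True,
)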
def parse_args() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(

@@ -488,57 +118,14 @@ def parse_args() -> argparse.ArgumentParser:
        help="run fastchat's controller/openai_api/model_worker servers, run api.py",
        dest="all_api",
    )
    parser.add_argument(
        "--llm-api",
        action="store_true",
        help="run fastchat's controller/openai_api/model_worker servers",
        dest="llm_api",
    )
    parser.add_argument(
        "-o",
        "--openai-api",
        action="store_true",
        help="run fastchat's controller/openai_api servers",
        dest="openai_api",
    )
    parser.add_argument(
        "-m",
        "--model-worker",
        action="store_true",
        help="run fastchat's model_worker server with specified model name. "
             "specify --model-name if not using default llm models",
        dest="model_worker",
    )
    parser.add_argument(
        "-n",
        "--model-name",
        type=str,
        nargs="+",
        default=all_model_names_list,
        help="specify model name for model worker. "
             "add addition names with space seperated to start multiple model workers.",
        dest="model_name",
    )
    parser.add_argument(
        "-c",
        "--controller",
        type=str,
        help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
        dest="controller_address",
    )

    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
        dest="api",
    )
    parser.add_argument(
        "-p",
        "--api-worker",
        action="store_true",
        help="run online model api such as zhipuai",
        dest="api_worker",
    )

    parser.add_argument(
        "-w",
        "--webui",

@@ -546,13 +133,6 @@ def parse_args() -> argparse.ArgumentParser:
        help="run webui.py server",
        dest="webui",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="减少fastchat服务log信息",
        dest="quiet",
    )
    parser.add_argument(
        "-i",
        "--lite",

@@ -578,24 +158,13 @@ def dump_server_info(after_start=False, args=None):
    print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
    print("\n")

    models = list(LLM_MODEL_CONFIG['llm_model'].keys())
    if args and args.model_name:
        models = args.model_name

    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
    print(f"当前启动的LLM模型:{models} @ {llm_device()}")

    for model in models:
        pprint(get_model_worker_config(model))
    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")

    if after_start:
        print("\n")
        print(f"服务端运行信息:")
        if args.openai_api:
            print(f" OpenAI API Server: {fschat_openai_api_address()}")
        if args.api:
            print(f" Chatchat API Server: {api_address()}")
        if args.webui:
            print(f" Chatchat WEBUI Server: {webui_address()}")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)

@@ -625,7 +194,6 @@ async def start_main_server():
    manager = mp.Manager()
    run_mode = None

    queue = manager.Queue()
    args, parser = parse_args()

    if args.all_webui:
@@ -662,69 +230,16 @@ async def start_main_server():
    processes = {"online_api": {}, "model_worker": {}}

    def process_count():
        return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2

    if args.quiet or not log_verbose:
        log_level = "ERROR"
    else:
        log_level = "INFO"

    controller_started = manager.Event()
    if args.openai_api:
        process = Process(
            target=run_controller,
            name=f"controller",
            kwargs=dict(log_level=log_level, started_event=controller_started),
            daemon=True,
        )
        processes["controller"] = process

        process = Process(
            target=run_openai_api,
            name=f"openai_api",
            daemon=True,
        )
        processes["openai_api"] = process

    model_worker_started = []
    if args.model_worker:
        for model_name in args.model_name:
            config = get_model_worker_config(model_name)
            if not config.get("online_api"):
                e = manager.Event()
                model_worker_started.append(e)
                process = Process(
                    target=run_model_worker,
                    name=f"model_worker - {model_name}",
                    kwargs=dict(model_name=model_name,
                                controller_address=args.controller_address,
                                log_level=log_level,
                                q=queue,
                                started_event=e),
                    daemon=True,
                )
                processes["model_worker"][model_name] = process

    if args.api_worker:
        for model_name in args.model_name:
            config = get_model_worker_config(model_name)
            if (config.get("online_api")
                    and config.get("worker_class")
                    and model_name in FSCHAT_MODEL_WORKERS):
                e = manager.Event()
                model_worker_started.append(e)
                process = Process(
                    target=run_model_worker,
                    name=f"api_worker - {model_name}",
                    kwargs=dict(model_name=model_name,
                                controller_address=args.controller_address,
                                log_level=log_level,
                                q=queue,
                                started_event=e),
                    daemon=True,
                )
                processes["online_api"][model_name] = process
        return len(processes)

    loom_started = manager.Event()
    process = Process(
        target=run_loom,
        name=f"run_loom Server",
        kwargs=dict(started_event=loom_started),
        daemon=True,
    )
    processes["run_loom"] = process
    api_started = manager.Event()
    if args.api:
        process = Process(

@@ -750,25 +265,10 @@ async def start_main_server():
    else:
        try:
            # 保证任务收到SIGINT后,能够正常退出
            if p := processes.get("controller"):
            if p := processes.get("run_loom"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                controller_started.wait()  # 等待controller启动完成

            if p := processes.get("openai_api"):
                p.start()
                p.name = f"{p.name} ({p.pid})"

            for n, p in processes.get("model_worker", {}).items():
                p.start()
                p.name = f"{p.name} ({p.pid})"

            for n, p in processes.get("online_api", []).items():
                p.start()
                p.name = f"{p.name} ({p.pid})"

            for e in model_worker_started:
                e.wait()
            loom_started.wait()  # 等待Loom启动完成

            if p := processes.get("api"):
                p.start()
@@ -782,74 +282,10 @@ async def start_main_server():

            dump_server_info(after_start=True, args=args)

            while True:
                cmd = queue.get()  # 收到切换模型的消息
                e = manager.Event()
                if isinstance(cmd, list):
                    model_name, cmd, new_model_name = cmd
                    if cmd == "start":  # 运行新模型
                        logger.info(f"准备启动新模型进程:{new_model_name}")
                        process = Process(
                            target=run_model_worker,
                            name=f"model_worker - {new_model_name}",
                            kwargs=dict(model_name=new_model_name,
                                        controller_address=args.controller_address,
                                        log_level=log_level,
                                        q=queue,
                                        started_event=e),
                            daemon=True,
                        )
                        process.start()
                        process.name = f"{process.name} ({process.pid})"
                        processes["model_worker"][new_model_name] = process
                        e.wait()
                        logger.info(f"成功启动新模型进程:{new_model_name}")
                    elif cmd == "stop":
                        if process := processes["model_worker"].get(model_name):
                            time.sleep(1)
                            process.terminate()
                            process.join()
                            logger.info(f"停止模型进程:{model_name}")
                        else:
                            logger.error(f"未找到模型进程:{model_name}")
                    elif cmd == "replace":
                        if process := processes["model_worker"].pop(model_name, None):
                            logger.info(f"停止模型进程:{model_name}")
                            start_time = datetime.now()
                            time.sleep(1)
                            process.terminate()
                            process.join()
                            process = Process(
                                target=run_model_worker,
                                name=f"model_worker - {new_model_name}",
                                kwargs=dict(model_name=new_model_name,
                                            controller_address=args.controller_address,
                                            log_level=log_level,
                                            q=queue,
                                            started_event=e),
                                daemon=True,
                            )
                            process.start()
                            process.name = f"{process.name} ({process.pid})"
                            processes["model_worker"][new_model_name] = process
                            e.wait()
                            timing = datetime.now() - start_time
                            logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
                        else:
                            logger.error(f"未找到模型进程:{model_name}")
            # 等待所有进程退出
            if p := processes.get("webui"):

                # for process in processes.get("model_worker", {}).values():
                #     process.join()
                # for process in processes.get("online_api", {}).values():
                #     process.join()

                # for name, process in processes.items():
                #     if name not in ["model_worker", "online_api"]:
                #         if isinstance(p, dict):
                #             for work_process in p.values():
                #                 work_process.join()
                #         else:
                #             process.join()
                p.join()
        except Exception as e:
            logger.error(e)
            logger.warning("Caught KeyboardInterrupt! Setting stop event...")

@@ -13,21 +13,6 @@ from typing import List
api = ApiRequest()


def test_get_default_llm():
    llm = api.get_default_llm_model()

    print(llm)
    assert isinstance(llm, tuple)
    assert isinstance(llm[0], str) and isinstance(llm[1], bool)


def test_server_configs():
    configs = api.get_server_configs()
    pprint(configs, depth=2)

    assert isinstance(configs, dict)
    assert len(configs) > 0


@pytest.mark.parametrize("type", ["llm_chat"])
def test_get_prompt_template(type):
@@ -8,13 +8,13 @@ from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest


workers = []
for x in list_config_llm_models()["online"]:
    if x in ONLINE_LLM_MODEL and x not in workers:
        workers.append(x)
print(f"all workers to test: {workers}")
#
#
# workers = []
# for x in list_config_llm_models()["online"]:
#     if x in ONLINE_LLM_MODEL and x not in workers:
#         workers.append(x)
# print(f"all workers to test: {workers}")

# workers = ["fangzhou-api"]

@@ -56,4 +56,4 @@ def test_embeddings(worker):
    assert isinstance(embeddings, list) and len(embeddings) > 0
    assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
    assert isinstance(embeddings[0][0], float)
    print("向量长度:", len(embeddings[0]))
    print("向量长度:", len(embeddings[0]))
webui.py (25 lines changed)

@@ -1,4 +1,6 @@
import streamlit as st

from webui_pages.openai_plugins import openai_plugins_page
from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages.dialogue.dialogue import dialogue_page, chat_box

@@ -22,9 +24,26 @@ if __name__ == "__main__":
            'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
            'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues",
            'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!"""
        }
        },
        layout="wide"

    )

    # use the following code to set the app to wide mode and the html markdown to increase the sidebar width
    st.markdown(
        """
        <style>
        [data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
            width: 350px;
        }
        [data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
            width: 600px;
            margin-left: -600px;
        }

        """,
        unsafe_allow_html=True,
    )
    pages = {
        "对话": {
            "icon": "chat",

@@ -34,6 +53,10 @@ if __name__ == "__main__":
            "icon": "hdd-stack",
            "func": knowledge_base_page,
        },
        "模型服务": {
            "icon": "hdd-stack",
            "func": openai_plugins_page,
        },
    }

    with st.sidebar:
@@ -1,8 +1,11 @@
import base64

import streamlit as st
from streamlit_antd_components.utils import ParseItems

from webui_pages.dialogue.utils import process_files
from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
    get_select_model_endpoint
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal

@@ -10,12 +13,15 @@ from datetime import datetime
import os
import re
import time
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY)
from server.callback_handler.agent_callback_handler import AgentStatus
from server.utils import MsgType
import uuid
from typing import List, Dict

import streamlit_antd_components as sac


chat_box = ChatBox(
    assistant_avatar=os.path.join(
        "img",

@@ -105,14 +111,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    st.session_state.setdefault("conversation_ids", {})
    st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
    st.session_state.setdefault("file_chat_id", None)
    default_model = api.get_default_llm_model()[0]

    if not chat_box.chat_inited:
        st.toast(
            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
            f"当前运行的模型`{default_model}`, 您可以开始提问了."
        )
        chat_box.init_session()
    st.session_state.setdefault("select_plugins_info", None)
    st.session_state.setdefault("select_model_worker", None)

    # 弹出自定义命令帮助信息
    modal = Modal("自定义命令", key="cmd_help", max_width="500")
@@ -131,57 +131,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        chat_box.use_chat_name(conversation_name)
        conversation_id = st.session_state["conversation_ids"][conversation_name]

        # def on_mode_change():
        #     mode = st.session_state.dialogue_mode
        #     text = f"已切换到 {mode} 模式。"
        #     st.toast(text)
        with st.expander("模型选择"):
            plugins_menu = build_plugins_name()

        # dialogue_modes = ["智能对话",
        #                   "文件对话",
        #                   ]
        # dialogue_mode = st.selectbox("请选择对话模式:",
        #                              dialogue_modes,
        #                              index=0,
        #                              on_change=on_mode_change,
        #                              key="dialogue_mode",
        #                              )
            items, _ = ParseItems(plugins_menu).multi()

        def on_llm_change():
            if llm_model:
                config = api.get_model_config(llm_model)
                if not config.get("online_api"):  # 只有本地model_worker可以切换模型
                    st.session_state["prev_llm_model"] = llm_model
                st.session_state["cur_llm_model"] = st.session_state.llm_model
            if len(plugins_menu) > 0:

        def llm_model_format_func(x):
            if x in running_models:
                return f"{x} (Running)"
            return x

        running_models = list(api.list_running_models())
        available_models = []
        config_models = api.list_config_models()
        if not is_lite:
            for k, v in config_models.get("local", {}).items():  # 列出配置了有效本地路径的模型
                if (v.get("model_path_exists")
                        and k not in running_models):
                    available_models.append(k)
        for k, v in config_models.get("online", {}).items():
            if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                available_models.append(k)
        llm_models = running_models + available_models
        cur_llm_model = st.session_state.get("cur_llm_model", default_model)
        if cur_llm_model in llm_models:
            index = llm_models.index(cur_llm_model)
        else:
            index = 0
        llm_model = st.selectbox("选择LLM模型",
                                 llm_models,
                                 index,
                                 format_func=llm_model_format_func,
                                 on_change=on_llm_change,
                                 key="llm_model",
                                 )
                llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
                plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
                set_llm_select(plugins_info, llm_model_worker)
            else:
                st.info("没有可用的插件")

        # 传入后端的内容
        model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
@@ -213,8 +174,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            if is_selected:
                selected_tool_configs[tool] = TOOL_CONFIG[tool]

        llm_model = None
        if st.session_state["select_model_worker"] is not None:
            llm_model = st.session_state["select_model_worker"]['label']

        if llm_model is not None:
            model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
            model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})

        uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
        files_upload = process_files(files=[uploaded_file]) if uploaded_file else None

@@ -258,7 +223,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                st.rerun()
            else:
                history = get_messages_history(
                    model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
                    model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1)
                )
                chat_box.user_say(prompt)
                if files_upload:
                    if files_upload["images"]:

@@ -277,10 +243,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                text = ""
                text_action = ""
                element_index = 0

                openai_config = {}
                endpoint_host, select_model_name = get_select_model_endpoint()
                openai_config["endpoint_host"] = endpoint_host
                openai_config["model_name"] = select_model_name
                openai_config["endpoint_host_key"] = OPENAI_KEY
                openai_config["endpoint_host_proxy"] = OPENAI_PROXY
                for d in api.chat_chat(query=prompt,
                                       metadata=files_upload,
                                       history=history,
                                       model_config=model_config,
                                       openai_config=openai_config,
                                       conversation_id=conversation_id,
                                       tool_config=selected_tool_configs,
                                       ):
@@ -3,6 +3,7 @@ import base64
import os
from io import BytesIO


def encode_file_to_base64(file):
    # 将文件内容转换为 Base64 编码
    buffer = BytesIO()
webui_pages/loom_view_client.py (new file, 138 lines)

@@ -0,0 +1,138 @@
from typing import Tuple, Any

import streamlit as st
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient

client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)


def update_store():
    st.session_state.status = client.status()
    list_plugins = client.list_plugins()
    st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
    launch_subscribe_info = {}
    for plugin_name in st.session_state.run_plugins_list:
        launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)

    st.session_state.launch_subscribe_info = launch_subscribe_info
    list_running_models = {}
    for plugin_name in st.session_state.run_plugins_list:
        list_running_models[plugin_name] = client.list_running_models(plugin_name)
    st.session_state.list_running_models = list_running_models

    model_config = {}
    for plugin_name in st.session_state.run_plugins_list:
        model_config[plugin_name] = client.list_llm_models(plugin_name)
    st.session_state.model_config = model_config


def start_plugin():
    import time
    start_plugins_name = st.session_state.plugins_name
    if start_plugins_name in st.session_state.run_plugins_list:
        st.toast(start_plugins_name + " has already been counted.")

        time.sleep(.5)
    else:

        st.toast("start_plugin " + start_plugins_name + ",starting.")
        result = client.launch_subscribe(start_plugins_name)
        st.toast("start_plugin "+start_plugins_name + " ." + result.get("detail", ""))
        time.sleep(3)
        result1 = client.launch_subscribe_start(start_plugins_name)

        st.toast("start_plugin "+start_plugins_name + " ." + result1.get("detail", ""))
        time.sleep(2)
        update_store()

def start_worker():
    select_plugins_name = st.session_state.plugins_name
    select_worker_id = st.session_state.worker_id
    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
    already_counted = False
    for model in start_model_list:
        if model['worker_id'] == select_worker_id:
            already_counted = True
            break

    if already_counted:
        st.toast(
            "select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
        import time
        time.sleep(.5)
    else:

        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
        result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
        st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
        import time
        time.sleep(.5)
        update_store()


def stop_worker():
    select_plugins_name = st.session_state.plugins_name
    select_worker_id = st.session_state.worker_id
    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
    already_counted = False
    for model in start_model_list:
        if model['worker_id'] == select_worker_id:
            already_counted = True
            break

    if not already_counted:
        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
        import time
        time.sleep(.5)
    else:

        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
        result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
        st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
        import time
        time.sleep(.5)
        update_store()

|
||||
import streamlit_antd_components as sac
|
||||
if "run_plugins_list" not in st.session_state:
|
||||
return []
|
||||
# 按照模型构建sac.menu(菜单
|
||||
menu_items = []
|
||||
for key, value in st.session_state.list_running_models.items():
|
||||
menu_item_children = []
|
||||
for model in value:
|
||||
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
|
||||
|
||||
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
|
||||
|
||||
return menu_items
|
||||
|
||||
|
||||
def find_menu_items_by_index(menu_items, key):
|
||||
for menu_item in menu_items:
|
||||
if menu_item.get('children') is not None:
|
||||
for child in menu_item.get('children'):
|
||||
if child.get('key') == key:
|
||||
return menu_item, child
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def set_llm_select(plugins_info, llm_model_worker):
|
||||
st.session_state["select_plugins_info"] = plugins_info
|
||||
st.session_state["select_model_worker"] = llm_model_worker
|
||||
|
||||
|
||||
def get_select_model_endpoint() -> Tuple[str, str]:
    plugins_info = st.session_state["select_plugins_info"]
    llm_model_worker = st.session_state["select_model_worker"]
    if plugins_info is None or llm_model_worker is None:
        raise ValueError("plugins_info or llm_model_worker is None")
    plugins_name = st.session_state["select_plugins_info"]['label']
    select_model_name = st.session_state["select_model_worker"]['label']
    adapter_description = st.session_state.launch_subscribe_info[plugins_name]
    endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
    return endpoint_host, select_model_name

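For orientation, a minimal sketch (not part of this diff) of how the selected endpoint and model could be packed into the openai_config / model_config payloads that the chat API accepts after this commit; the temperature and max_tokens values are illustrative assumptions:

endpoint_host, model_name = get_select_model_endpoint()
openai_config = {"endpoint_host": endpoint_host}  # API-key / proxy fields could be added here (assumptions)
model_config = {"llm_model": {model_name: {"temperature": 0.7, "max_tokens": 1000}}}
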
webui_pages/openai_plugins/__init__.py (new file, 1 line)
@ -0,0 +1 @@
from .base import openai_plugins_page

webui_pages/openai_plugins/base.py (new file, 75 lines)
@ -0,0 +1,75 @@
import streamlit as st
from loom_openai_plugins_frontend import loom_openai_plugins_frontend

from webui_pages.utils import ApiRequest
from webui_pages.loom_view_client import (
    update_store,
    start_plugin,
    start_worker,
    stop_worker,
)


def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
    if "status" not in st.session_state \
            or "run_plugins_list" not in st.session_state \
            or "launch_subscribe_info" not in st.session_state \
            or "list_running_models" not in st.session_state \
            or "model_config" not in st.session_state:
        update_store()

    with st.container():

        if "worker_id" not in st.session_state:
            st.session_state.worker_id = ''
        if "plugins_name" not in st.session_state and "status" in st.session_state:
            # Default to the first subscribed plugin's first model name.
            for k, v in st.session_state.status.get("status", {}).get("subscribe_status", {}).items():
                st.session_state.plugins_name = v.get("plugins_names", [])[0]
                break

        col1, col2 = st.columns([0.8, 0.2])

        with col1:
            event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
                                                 run_list_plugins=st.session_state.run_plugins_list,
                                                 launch_subscribe_info=st.session_state.launch_subscribe_info,
                                                 list_running_models=st.session_state.list_running_models,
                                                 model_config=st.session_state.model_config)

        with col2:
            st.write("操作")
            if not st.session_state.run_plugins_list:
                button_type_disabled = False
                button_start_text = '启动'
            else:
                button_type_disabled = True
                button_start_text = '已启动'

            if event:
                event_type = event.get("event")
                if event_type == "BottomNavigationAction":
                    st.session_state.plugins_name = event.get("data")
                    st.session_state.worker_id = ''
                    # If the selected plugin is not in run_plugins_list, enable the start button.
                    if st.session_state.plugins_name not in st.session_state.run_plugins_list:
                        button_type_disabled = False
                        button_start_text = '启动'
                    else:
                        button_type_disabled = True
                        button_start_text = '已启动'
                if event_type == "CardCoverComponent":
                    st.session_state.worker_id = event.get("data")

            st.button(button_start_text, disabled=button_type_disabled, key="start",
                      on_click=start_plugin)

            if st.session_state.worker_id != '':
                st.button("启动" + st.session_state.worker_id, key="start_worker",
                          on_click=start_worker)
                st.button("停止" + st.session_state.worker_id, key="stop_worker",
                          on_click=stop_worker)

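For context, one hypothetical way to mount the new page from a Streamlit entry script; the module layout, page registration and constructor arguments below are assumptions rather than part of this diff:

import streamlit as st

from webui_pages.utils import ApiRequest
from webui_pages.openai_plugins import openai_plugins_page

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed API server address

st.set_page_config(page_title="openai plugins")
openai_plugins_page(api)
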
@ -266,6 +266,7 @@ class ApiRequest:
            history: List[Dict] = [],
            stream: bool = True,
            model_config: Dict = None,
            openai_config: Dict = None,
            tool_config: Dict = None,
            **kwargs,
    ):
@ -280,6 +281,7 @@ class ApiRequest:
            "history": history,
            "stream": stream,
            "model_config": model_config,
            "openai_config": openai_config,
            "tool_config": tool_config,
        }

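A hypothetical call site for the updated wrapper (the method name api.chat_chat, the model name and the endpoint URL are assumptions; only the openai_config field itself comes from this diff):

for chunk in api.chat_chat(
        query="你好",
        stream=True,
        model_config={"llm_model": {"chatglm3-6b": {"temperature": 0.7}}},
        openai_config={"endpoint_host": "http://127.0.0.1:20000/v1"},
        tool_config={},
):
    print(chunk)
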
@ -619,219 +621,6 @@ class ApiRequest:
        )
        return self._httpx_stream2generator(response, as_json=True)

    # LLM model management operations
    def list_running_models(
        self,
        controller_address: str = None,
    ):
        '''
        Get the list of models currently running in Fastchat.
        '''
        data = {
            "controller_address": controller_address,
        }

        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {data}')

        response = self.post(
            "/llm_model/list_running_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
        '''
        Get the currently running LLM model from the server.
        When local_first=True, prefer a running local model; otherwise follow the order
        configured in LLM_MODEL_CONFIG['llm_model'].
        Returns (model_name, is_local_model).
        '''

        def ret_sync():
            running_models = self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODEL_CONFIG['llm_model']:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODEL_CONFIG['llm_model'] is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local

        async def ret_async():
            running_models = await self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODEL_CONFIG['llm_model']:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODEL_CONFIG['llm_model'] is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()

    def list_config_models(
        self,
        types: List[str] = ["local", "online"],
    ) -> Dict[str, Dict]:
        '''
        Get the models configured in the server's configs, returned as
        {"type": {model_name: config}, ...}.
        '''
        data = {
            "types": types,
        }
        response = self.post(
            "/llm_model/list_config_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def get_model_config(
        self,
        model_name: str = None,
    ) -> Dict:
        '''
        Get the model configuration from the server.
        '''
        data = {
            "model_name": model_name,
        }
        response = self.post(
            "/llm_model/get_model_config",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def stop_llm_model(
        self,
        model_name: str,
        controller_address: str = None,
    ):
        '''
        Stop an LLM model.
        Note: because of how Fastchat is implemented, this actually stops the model_worker hosting the model.
        '''
        data = {
            "model_name": model_name,
            "controller_address": controller_address,
        }

        response = self.post(
            "/llm_model/stop",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def change_llm_model(
        self,
        model_name: str,
        new_model_name: str,
        controller_address: str = None,
    ):
        '''
        Ask the fastchat controller to switch the LLM model.
        '''
        if not model_name or not new_model_name:
            return {
                "code": 500,
                "msg": "未指定模型名称"
            }

        def ret_sync():
            running_models = self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "无需切换"
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
                }

            config_models = self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }

            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        async def ret_async():
            running_models = await self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "无需切换"
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
                }

            config_models = await self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }

            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()

    def embed_texts(
        self,
        texts: List[str],