Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-02-08 16:10:18 +08:00

Commit 48fb6b83fd (parent 8063aab7a1): Integrate the OpenAI plugins service (集成openai plugins插件)
@@ -8,7 +8,7 @@ torchaudio>=2.1.2
 # Langchain 0.1.x requirements
 langchain>=0.1.0
 langchain_openai>=0.0.2
-langchain-community>=1.0.0
+langchain-community>=0.0.11
 langchainhub>=0.1.14

 pydantic==1.10.13
@@ -17,9 +17,7 @@ from server.chat.chat import chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
 from server.embeddings_api import embed_texts_endpoint
-from server.llm_api import (list_running_models, list_config_models,
-                            change_llm_model, stop_llm_model,
-                            get_model_config)
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
 from typing import List, Literal
@@ -73,32 +71,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
     # 摘要相关接口
     mount_filename_summary_routes(app)

-    # LLM模型相关接口
-    app.post("/llm_model/list_running_models",
-             tags=["LLM Model Management"],
-             summary="列出当前已加载的模型",
-             )(list_running_models)
-
-    app.post("/llm_model/list_config_models",
-             tags=["LLM Model Management"],
-             summary="列出configs已配置的模型",
-             )(list_config_models)
-
-    app.post("/llm_model/get_model_config",
-             tags=["LLM Model Management"],
-             summary="获取模型配置(合并后)",
-             )(get_model_config)
-
-    app.post("/llm_model/stop",
-             tags=["LLM Model Management"],
-             summary="停止指定的LLM模型(Model Worker)",
-             )(stop_llm_model)
-
-    app.post("/llm_model/change",
-             tags=["LLM Model Management"],
-             summary="切换指定的LLM模型(Model Worker)",
-             )(change_llm_model)
-
     # 服务器相关接口
     app.post("/server/configs",
              tags=["Server State"],
@@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus


-def create_models_from_config(configs, callbacks, stream):
+def create_models_from_config(configs, openai_config, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
@@ -30,6 +30,9 @@ def create_models_from_config(configs, callbacks, stream):
     for model_name, params in model_configs.items():
         callbacks = callbacks if params.get('callbacks', False) else None
         model_instance = get_ChatOpenAI(
+            endpoint_host=openai_config.get('endpoint_host', None),
+            endpoint_host_key=openai_config.get('endpoint_host_key', None),
+            endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
            model_name=model_name,
            temperature=params.get('temperature', 0.5),
            max_tokens=params.get('max_tokens', 1000),
@@ -113,6 +116,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                ),
                stream: bool = Body(True, description="流式输出"),
                model_config: Dict = Body({}, description="LLM 模型配置"),
+               openai_config: Dict = Body({}, description="openaiEndpoint配置"),
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
@@ -124,7 +128,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼

         callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
-        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
+        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
+                                                    openai_config=openai_config, stream=stream)
         tools = [tool for tool in all_tools if tool.name in tool_config]
         tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         full_chain = create_models_chains(prompts=prompts,
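For orientation, a minimal client-side sketch of the request body the chat handler accepts after this change. Only the body field names (query, stream, model_config, openai_config, tool_config) and the three endpoint keys come from the diff; the base URL, route path, and sample values are assumptions for illustration.

import requests

payload = {
    "query": "你好",
    "stream": False,
    "model_config": {},        # per-model LLM settings, same shape as before this commit
    "openai_config": {         # new in this commit: OpenAI-style endpoint settings
        "endpoint_host": "https://api.openai.com/v1",
        "endpoint_host_key": "sk-...",
        "endpoint_host_proxy": None,
    },
    "tool_config": {},
}
# Base URL and route path are assumed; adjust to your deployment.
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)
print(resp.text)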
@@ -14,6 +14,9 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
+                     endpoint_host: str = Body(False, description="接入点地址"),
+                     endpoint_host_key: str = Body(False, description="接入点key"),
+                     endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
                      model_name: str = Body(None, description="LLM 模型名称。"),
                      temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
@@ -24,6 +27,9 @@ async def completion(query: str = Body(..., description="用户输入", examples

     #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
+                                  endpoint_host: str,
+                                  endpoint_host_key: str,
+                                  endpoint_host_proxy: str,
                                   model_name: str = None,
                                   prompt_name: str = prompt_name,
                                   echo: bool = echo,
@@ -34,6 +40,9 @@ async def completion(query: str = Body(..., description="用户输入", examples
             max_tokens = None

         model = get_OpenAI(
+            endpoint_host=endpoint_host,
+            endpoint_host_key=endpoint_host_key,
+            endpoint_host_proxy=endpoint_host_proxy,
             model_name=model_name,
             temperature=temperature,
             max_tokens=max_tokens,
@@ -63,7 +72,10 @@ async def completion(query: str = Body(..., description="用户输入", examples

     await task

-    return EventSourceResponse(completion_iterator(query=query,
+    return StreamingResponse(completion_iterator(query=query,
+                                                 endpoint_host=endpoint_host,
+                                                 endpoint_host_key=endpoint_host_key,
+                                                 endpoint_host_proxy=endpoint_host_proxy,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
                              )
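The completion handler above takes the same endpoint settings, but as flat top-level body fields rather than a nested openai_config object. A hedged sketch, with only the field names taken from the diff and the URL, route path, and values assumed:

import requests

payload = {
    "query": "恼羞成怒",
    "stream": False,
    "echo": False,
    "endpoint_host": "https://api.openai.com/v1",   # 接入点地址
    "endpoint_host_key": "sk-...",                  # 接入点key
    "endpoint_host_proxy": None,                    # 接入点代理地址
    "model_name": "gpt-3.5-turbo",
    "temperature": 0.01,
    "max_tokens": 1024,
}
# The route path for the completion endpoint is assumed here.
resp = requests.post("http://127.0.0.1:7861/other/completion", json=payload)
print(resp.text)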
@@ -1,124 +0,0 @@
-from fastapi import Body
-from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
-from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
-                          get_httpx_client, get_model_worker_config)
-from typing import List
-
-
-def list_running_models(
-        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
-        placeholder: str = Body(None, description="该参数未使用,占位用"),
-) -> BaseResponse:
-    '''
-    从fastchat controller获取已加载模型列表及其配置项
-    '''
-    try:
-        controller_address = controller_address or fschat_controller_address()
-        with get_httpx_client() as client:
-            r = client.post(controller_address + "/list_models")
-            models = r.json()["models"]
-            data = {m: get_model_config(m).data for m in models}
-
-            ## 只有LLM模型才返回
-            result = {}
-            for model, config in data.items():
-                if model in LLM_MODEL_CONFIG['llm_model']:
-                    result[model] = config
-
-            return BaseResponse(data=result)
-    except Exception as e:
-        logger.error(f'{e.__class__.__name__}: {e}',
-                     exc_info=e if log_verbose else None)
-        return BaseResponse(
-            code=500,
-            data={},
-            msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
-
-
-def list_config_models(
-        types: List[str] = Body(["local", "online"], description="模型配置项类别,如local, online, worker"),
-        placeholder: str = Body(None, description="占位用,无实际效果")
-) -> BaseResponse:
-    '''
-    从本地获取configs中配置的模型列表
-    '''
-    data = {}
-    result = {}
-
-    for type, models in list_config_llm_models().items():
-        if type in types:
-            data[type] = {m: get_model_config(m).data for m in models}
-
-    for model, config in data.items():
-        if model in LLM_MODEL_CONFIG['llm_model']:
-            result[type][model] = config
-
-    return BaseResponse(data=result)
-
-
-def get_model_config(
-        model_name: str = Body(description="配置中LLM模型的名称"),
-        placeholder: str = Body(None, description="占位用,无实际效果")
-) -> BaseResponse:
-    '''
-    获取LLM模型配置项(合并后的)
-    '''
-    config = {}
-    # 删除ONLINE_MODEL配置中的敏感信息
-    for k, v in get_model_worker_config(model_name=model_name).items():
-        if not (k == "worker_class"
-                or "key" in k.lower()
-                or "secret" in k.lower()
-                or k.lower().endswith("id")):
-            config[k] = v
-
-    return BaseResponse(data=config)
-
-
-def stop_llm_model(
-        model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
-        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
-) -> BaseResponse:
-    '''
-    向fastchat controller请求停止某个LLM模型。
-    注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
-    '''
-    try:
-        controller_address = controller_address or fschat_controller_address()
-        with get_httpx_client() as client:
-            r = client.post(
-                controller_address + "/release_worker",
-                json={"model_name": model_name},
-            )
-            return r.json()
-    except Exception as e:
-        logger.error(f'{e.__class__.__name__}: {e}',
-                     exc_info=e if log_verbose else None)
-        return BaseResponse(
-            code=500,
-            msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
-
-
-def change_llm_model(
-        model_name: str = Body(..., description="当前运行模型"),
-        new_model_name: str = Body(..., description="要切换的新模型"),
-        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
-):
-    '''
-    向fastchat controller请求切换LLM模型。
-    '''
-    try:
-        controller_address = controller_address or fschat_controller_address()
-        with get_httpx_client() as client:
-            r = client.post(
-                controller_address + "/release_worker",
-                json={"model_name": model_name, "new_model_name": new_model_name},
-                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for new worker_model
-            )
-            return r.json()
-    except Exception as e:
-        logger.error(f'{e.__class__.__name__}: {e}',
-                     exc_info=e if log_verbose else None)
-        return BaseResponse(
-            code=500,
-            msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
server/utils.py: 104 changed lines

@@ -45,6 +45,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):


 def get_ChatOpenAI(
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str,
         model_name: str,
         temperature: float,
         max_tokens: int = None,
@@ -61,18 +64,21 @@ def get_ChatOpenAI(
         streaming=streaming,
         verbose=verbose,
         callbacks=callbacks,
-        openai_api_key=config.get("api_key", "EMPTY"),
-        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+        openai_api_key=endpoint_host_key if endpoint_host_key else "None",
+        openai_api_base=endpoint_host if endpoint_host else "None",
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
-        openai_proxy=config.get("openai_proxy"),
+        openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
         **kwargs
     )
     return model


 def get_OpenAI(
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str,
         model_name: str,
         temperature: float,
         max_tokens: int = None,
@@ -82,19 +88,18 @@ def get_OpenAI(
         verbose: bool = True,
         **kwargs: Any,
 ) -> OpenAI:
-    config = get_model_worker_config(model_name)
-    if model_name == "openai-api":
-        model_name = config.get("model_name")
+    # TODO: 从API获取模型信息
     model = OpenAI(
         streaming=streaming,
         verbose=verbose,
         callbacks=callbacks,
-        openai_api_key=config.get("api_key", "EMPTY"),
-        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+        openai_api_key=endpoint_host_key if endpoint_host_key else "None",
+        openai_api_base=endpoint_host if endpoint_host else "None",
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
-        openai_proxy=config.get("openai_proxy"),
+        openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
         echo=echo,
         **kwargs
     )
@@ -365,69 +370,14 @@ def get_model_worker_config(model_name: str = None) -> dict:
     '''
     from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
     from configs.server_config import FSCHAT_MODEL_WORKERS
-    from server import model_workers

     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
-    if model_name in ONLINE_LLM_MODEL:
-        config["online_api"] = True
-        if provider := config.get("provider"):
-            try:
-                config["worker_class"] = getattr(model_workers, provider)
-            except Exception as e:
-                msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
-                logger.error(f'{e.__class__.__name__}: {msg}',
-                             exc_info=e if log_verbose else None)
-    # 本地模型
-    if model_name in MODEL_PATH["llm_model"]:
-        path = get_model_path(model_name)
-        config["model_path"] = path
-        if path and os.path.isdir(path):
-            config["model_path_exists"] = True
-        config["device"] = llm_device(config.get("device"))
     return config


-def get_all_model_worker_configs() -> dict:
-    result = {}
-    model_names = set(FSCHAT_MODEL_WORKERS.keys())
-    for name in model_names:
-        if name != "default":
-            result[name] = get_model_worker_config(name)
-    return result
-
-
-def fschat_controller_address() -> str:
-    from configs.server_config import FSCHAT_CONTROLLER
-
-    host = FSCHAT_CONTROLLER["host"]
-    if host == "0.0.0.0":
-        host = "127.0.0.1"
-    port = FSCHAT_CONTROLLER["port"]
-    return f"http://{host}:{port}"
-
-
-def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
-    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
-        host = model["host"]
-        if host == "0.0.0.0":
-            host = "127.0.0.1"
-        port = model["port"]
-        return f"http://{host}:{port}"
-    return ""
-
-
-def fschat_openai_api_address() -> str:
-    from configs.server_config import FSCHAT_OPENAI_API
-
-    host = FSCHAT_OPENAI_API["host"]
-    if host == "0.0.0.0":
-        host = "127.0.0.1"
-    port = FSCHAT_OPENAI_API["port"]
-    return f"http://{host}:{port}/v1"
-
-
 def api_address() -> str:
     from configs.server_config import API_SERVER

@@ -461,6 +411,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
 def set_httpx_config(
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         proxy: Union[str, Dict] = None,
+        unused_proxies: List[str] = [],
 ):
     '''
     设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
@@ -498,11 +449,7 @@ def set_httpx_config(
         "http://localhost",
     ]
     # do not use proxy for user deployed fastchat servers
-    for x in [
-        fschat_controller_address(),
-        fschat_model_worker_address(),
-        fschat_openai_api_address(),
-    ]:
+    for x in unused_proxies:
         host = ":".join(x.split(":")[:2])
         if host not in no_proxy:
             no_proxy.append(host)
@@ -568,6 +515,7 @@ def get_httpx_client(
         use_async: bool = False,
         proxies: Union[str, Dict] = None,
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
+        unused_proxies: List[str] = [],
         **kwargs,
 ) -> Union[httpx.Client, httpx.AsyncClient]:
     '''
@@ -579,11 +527,7 @@ def get_httpx_client(
         "all://localhost": None,
     }
     # do not use proxy for user deployed fastchat servers
-    for x in [
-        fschat_controller_address(),
-        fschat_model_worker_address(),
-        fschat_openai_api_address(),
-    ]:
+    for x in unused_proxies:
         host = ":".join(x.split(":")[:2])
         default_proxies.update({host: None})

@@ -629,8 +573,6 @@ def get_server_configs() -> Dict:
     获取configs中的原始配置项,供前端使用
     '''
     _custom = {
-        "controller_address": fschat_controller_address(),
-        "openai_api_address": fschat_openai_api_address(),
         "api_address": api_address(),
     }

@@ -638,14 +580,8 @@ def get_server_configs() -> Dict:


 def list_online_embed_models() -> List[str]:
-    from server import model_workers
-
     ret = []
-    for k, v in list_config_llm_models()["online"].items():
-        if provider := v.get("provider"):
-            worker_class = getattr(model_workers, provider, None)
-            if worker_class is not None and worker_class.can_embedding():
-                ret.append(k)
+    # TODO: 从在线API获取支持的模型列表
     return ret

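To make the new get_ChatOpenAI signature concrete, a minimal sketch of a direct call with the endpoint parameters added above. The parameter names and fallback behaviour come from the diff; the concrete values are placeholders, and remaining keyword arguments are left at their defaults.

from server.utils import get_ChatOpenAI

model = get_ChatOpenAI(
    endpoint_host="https://api.openai.com/v1",  # used as openai_api_base (falls back to "None" if empty)
    endpoint_host_key="sk-...",                 # used as openai_api_key (falls back to "None" if empty)
    endpoint_host_proxy=None,                   # used as openai_proxy (stays None when unset)
    model_name="gpt-3.5-turbo",
    temperature=0.7,
    max_tokens=1000,
)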
startup.py: 622 changed lines

@@ -25,17 +25,11 @@ from configs import (
     LLM_MODEL_CONFIG,
     EMBEDDING_MODEL,
     TEXT_SPLITTER_NAME,
-    FSCHAT_CONTROLLER,
-    FSCHAT_OPENAI_API,
-    FSCHAT_MODEL_WORKERS,
     API_SERVER,
     WEBUI_SERVER,
     HTTPX_DEFAULT_TIMEOUT,
 )
-from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, get_httpx_client,
-                          get_model_worker_config,
-                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
+from server.utils import (FastAPI, embedding_device)
 from server.knowledge_base.migrate import create_tables
 import argparse
 from typing import List, Dict
@@ -49,234 +43,6 @@ for model_category in LLM_MODEL_CONFIG.values():

 all_model_names_list = list(all_model_names)

-def create_controller_app(
-        dispatch_method: str,
-        log_level: str = "INFO",
-) -> FastAPI:
-    import fastchat.constants
-    fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.controller import app, Controller, logger
-    logger.setLevel(log_level)
-
-    controller = Controller(dispatch_method)
-    sys.modules["fastchat.serve.controller"].controller = controller
-
-    MakeFastAPIOffline(app)
-    app.title = "FastChat Controller"
-    app._controller = controller
-    return app
-
-
-def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
-    """
-    kwargs包含的字段如下:
-    host:
-    port:
-    model_names:[`model_name`]
-    controller_address:
-    worker_address:
-
-    对于Langchain支持的模型:
-        langchain_model:True
-        不会使用fschat
-    对于online_api:
-        online_api:True
-        worker_class: `provider`
-    对于离线模型:
-        model_path: `model_name_or_path`,huggingface的repo-id或本地路径
-        device:`LLM_DEVICE`
-    """
-    import fastchat.constants
-    fastchat.constants.LOGDIR = LOG_PATH
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    args = parser.parse_args([])
-
-    for k, v in kwargs.items():
-        setattr(args, k, v)
-    if worker_class := kwargs.get("langchain_model"):  # Langchian支持的模型不用做操作
-        from fastchat.serve.base_model_worker import app
-        worker = ""
-    # 在线模型API
-    elif worker_class := kwargs.get("worker_class"):
-        from fastchat.serve.base_model_worker import app
-
-        worker = worker_class(model_names=args.model_names,
-                              controller_addr=args.controller_address,
-                              worker_addr=args.worker_address)
-        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
-        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
-
-    # 本地模型
-    else:
-        from configs.model_config import VLLM_MODEL_DICT
-        if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
-            import fastchat.serve.vllm_worker
-            from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
-            from vllm import AsyncLLMEngine
-            from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-
-            args.tokenizer = args.model_path  # 如果tokenizer与model_path不一致在此处添加
-            args.tokenizer_mode = 'auto'
-            args.trust_remote_code = True
-            args.download_dir = None
-            args.load_format = 'auto'
-            args.dtype = 'auto'
-            args.seed = 0
-            args.worker_use_ray = False
-            args.pipeline_parallel_size = 1
-            args.tensor_parallel_size = 1
-            args.block_size = 16
-            args.swap_space = 4  # GiB
-            args.gpu_memory_utilization = 0.90
-            args.max_num_batched_tokens = None  # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够
-            args.max_num_seqs = 256
-            args.disable_log_stats = False
-            args.conv_template = None
-            args.limit_worker_concurrency = 5
-            args.no_register = False
-            args.num_gpus = 4  # vllm worker的切分是tensor并行,这里填写显卡的数量
-            args.engine_use_ray = False
-            args.disable_log_requests = False
-
-            # 0.2.2 vllm后要加的参数, 但是这里不需要
-            args.max_model_len = None
-            args.revision = None
-            args.quantization = None
-            args.max_log_len = None
-            args.tokenizer_revision = None
-
-            # 0.2.2 vllm需要新加的参数
-            args.max_paddings = 256
-
-            if args.model_path:
-                args.model = args.model_path
-            if args.num_gpus > 1:
-                args.tensor_parallel_size = args.num_gpus
-
-            for k, v in kwargs.items():
-                setattr(args, k, v)
-
-            engine_args = AsyncEngineArgs.from_cli_args(args)
-            engine = AsyncLLMEngine.from_engine_args(engine_args)
-
-            worker = VLLMWorker(
-                controller_addr=args.controller_address,
-                worker_addr=args.worker_address,
-                worker_id=worker_id,
-                model_path=args.model_path,
-                model_names=args.model_names,
-                limit_worker_concurrency=args.limit_worker_concurrency,
-                no_register=args.no_register,
-                llm_engine=engine,
-                conv_template=args.conv_template,
-            )
-            sys.modules["fastchat.serve.vllm_worker"].engine = engine
-            sys.modules["fastchat.serve.vllm_worker"].worker = worker
-            sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
-
-        else:
-            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
-
-            args.gpus = "0"  # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
-            args.max_gpu_memory = "22GiB"
-            args.num_gpus = 1  # model worker的切分是model并行,这里填写显卡的数量
-
-            args.load_8bit = False
-            args.cpu_offloading = None
-            args.gptq_ckpt = None
-            args.gptq_wbits = 16
-            args.gptq_groupsize = -1
-            args.gptq_act_order = False
-            args.awq_ckpt = None
-            args.awq_wbits = 16
-            args.awq_groupsize = -1
-            args.model_names = [""]
-            args.conv_template = None
-            args.limit_worker_concurrency = 5
-            args.stream_interval = 2
-            args.no_register = False
-            args.embed_in_truncate = False
-            for k, v in kwargs.items():
-                setattr(args, k, v)
-            if args.gpus:
-                if args.num_gpus is None:
-                    args.num_gpus = len(args.gpus.split(','))
-                if len(args.gpus.split(",")) < args.num_gpus:
-                    raise ValueError(
-                        f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
-                    )
-                os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
-            gptq_config = GptqConfig(
-                ckpt=args.gptq_ckpt or args.model_path,
-                wbits=args.gptq_wbits,
-                groupsize=args.gptq_groupsize,
-                act_order=args.gptq_act_order,
-            )
-            awq_config = AWQConfig(
-                ckpt=args.awq_ckpt or args.model_path,
-                wbits=args.awq_wbits,
-                groupsize=args.awq_groupsize,
-            )
-            worker = ModelWorker(
-                controller_addr=args.controller_address,
-                worker_addr=args.worker_address,
-                worker_id=worker_id,
-                model_path=args.model_path,
-                model_names=args.model_names,
-                limit_worker_concurrency=args.limit_worker_concurrency,
-                no_register=args.no_register,
-                device=args.device,
-                num_gpus=args.num_gpus,
-                max_gpu_memory=args.max_gpu_memory,
-                load_8bit=args.load_8bit,
-                cpu_offloading=args.cpu_offloading,
-                gptq_config=gptq_config,
-                awq_config=awq_config,
-                stream_interval=args.stream_interval,
-                conv_template=args.conv_template,
-                embed_in_truncate=args.embed_in_truncate,
-            )
-            sys.modules["fastchat.serve.model_worker"].args = args
-            sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-            # sys.modules["fastchat.serve.model_worker"].worker = worker
-            sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
-
-    MakeFastAPIOffline(app)
-    app.title = f"FastChat LLM Server ({args.model_names[0]})"
-    app._worker = worker
-    return app
-
-
-def create_openai_api_app(
-        controller_address: str,
-        api_keys: List = [],
-        log_level: str = "INFO",
-) -> FastAPI:
-    import fastchat.constants
-    fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
-    from fastchat.utils import build_logger
-    logger = build_logger("openai_api", "openai_api.log")
-    logger.setLevel(log_level)
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_credentials=True,
-        allow_origins=["*"],
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
-
-    sys.modules["fastchat.serve.openai_api_server"].logger = logger
-    app_settings.controller_address = controller_address
-    app_settings.api_keys = api_keys
-
-    MakeFastAPIOffline(app)
-    app.title = "FastChat OpeanAI API Server"
-    return app
-
-
 def _set_app_event(app: FastAPI, started_event: mp.Event = None):
     @app.on_event("startup")
@@ -285,154 +51,6 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
         started_event.set()


-def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
-    import uvicorn
-    import httpx
-    from fastapi import Body
-    import time
-    import sys
-    from server.utils import set_httpx_config
-    set_httpx_config()
-
-    app = create_controller_app(
-        dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
-        log_level=log_level,
-    )
-    _set_app_event(app, started_event)
-
-    # add interface to release and load model worker
-    @app.post("/release_worker")
-    def release_worker(
-            model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
-            # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
-            new_model_name: str = Body(None, description="释放后加载该模型"),
-            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
-    ) -> Dict:
-        available_models = app._controller.list_models()
-        if new_model_name in available_models:
-            msg = f"要切换的LLM模型 {new_model_name} 已经存在"
-            logger.info(msg)
-            return {"code": 500, "msg": msg}
-
-        if new_model_name:
-            logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
-        else:
-            logger.info(f"即将停止LLM模型: {model_name}")
-
-        if model_name not in available_models:
-            msg = f"the model {model_name} is not available"
-            logger.error(msg)
-            return {"code": 500, "msg": msg}
-
-        worker_address = app._controller.get_worker_address(model_name)
-        if not worker_address:
-            msg = f"can not find model_worker address for {model_name}"
-            logger.error(msg)
-            return {"code": 500, "msg": msg}
-
-        with get_httpx_client() as client:
-            r = client.post(worker_address + "/release",
-                            json={"new_model_name": new_model_name, "keep_origin": keep_origin})
-            if r.status_code != 200:
-                msg = f"failed to release model: {model_name}"
-                logger.error(msg)
-                return {"code": 500, "msg": msg}
-
-        if new_model_name:
-            timer = HTTPX_DEFAULT_TIMEOUT  # wait for new model_worker register
-            while timer > 0:
-                models = app._controller.list_models()
-                if new_model_name in models:
-                    break
-                time.sleep(1)
-                timer -= 1
-            if timer > 0:
-                msg = f"sucess change model from {model_name} to {new_model_name}"
-                logger.info(msg)
-                return {"code": 200, "msg": msg}
-            else:
-                msg = f"failed change model from {model_name} to {new_model_name}"
-                logger.error(msg)
-                return {"code": 500, "msg": msg}
-        else:
-            msg = f"sucess to release model: {model_name}"
-            logger.info(msg)
-            return {"code": 200, "msg": msg}
-
-    host = FSCHAT_CONTROLLER["host"]
-    port = FSCHAT_CONTROLLER["port"]
-
-    if log_level == "ERROR":
-        sys.stdout = sys.__stdout__
-        sys.stderr = sys.__stderr__
-
-    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
-
-
-def run_model_worker(
-        model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
-        controller_address: str = "",
-        log_level: str = "INFO",
-        q: mp.Queue = None,
-        started_event: mp.Event = None,
-):
-    import uvicorn
-    from fastapi import Body
-    import sys
-    from server.utils import set_httpx_config
-    set_httpx_config()
-
-    kwargs = get_model_worker_config(model_name)
-    host = kwargs.pop("host")
-    port = kwargs.pop("port")
-    kwargs["model_names"] = [model_name]
-    kwargs["controller_address"] = controller_address or fschat_controller_address()
-    kwargs["worker_address"] = fschat_model_worker_address(model_name)
-    model_path = kwargs.get("model_path", "")
-    kwargs["model_path"] = model_path
-    app = create_model_worker_app(log_level=log_level, **kwargs)
-    _set_app_event(app, started_event)
-    if log_level == "ERROR":
-        sys.stdout = sys.__stdout__
-        sys.stderr = sys.__stderr__
-
-    # add interface to release and load model
-    @app.post("/release")
-    def release_model(
-            new_model_name: str = Body(None, description="释放后加载该模型"),
-            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
-    ) -> Dict:
-        if keep_origin:
-            if new_model_name:
-                q.put([model_name, "start", new_model_name])
-        else:
-            if new_model_name:
-                q.put([model_name, "replace", new_model_name])
-            else:
-                q.put([model_name, "stop", None])
-        return {"code": 200, "msg": "done"}
-
-    uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
-
-
-def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
-    import uvicorn
-    import sys
-    from server.utils import set_httpx_config
-    set_httpx_config()
-
-    controller_addr = fschat_controller_address()
-    app = create_openai_api_app(controller_addr, log_level=log_level)
-    _set_app_event(app, started_event)
-
-    host = FSCHAT_OPENAI_API["host"]
-    port = FSCHAT_OPENAI_API["port"]
-    if log_level == "ERROR":
-        sys.stdout = sys.__stdout__
-        sys.stderr = sys.__stderr__
-    uvicorn.run(app, host=host, port=port)
-
-
 def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     from server.api import create_app
     import uvicorn
@@ -473,6 +91,18 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
     p.wait()


+def run_loom(started_event: mp.Event = None):
+    from configs import LOOM_CONFIG
+
+    cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
+           "-f", LOOM_CONFIG
+           ]
+
+    p = subprocess.Popen(cmd)
+    started_event.set()
+    p.wait()
+
+
 def parse_args() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -488,57 +118,14 @@ def parse_args() -> argparse.ArgumentParser:
         help="run fastchat's controller/openai_api/model_worker servers, run api.py",
         dest="all_api",
     )
-    parser.add_argument(
-        "--llm-api",
-        action="store_true",
-        help="run fastchat's controller/openai_api/model_worker servers",
-        dest="llm_api",
-    )
-    parser.add_argument(
-        "-o",
-        "--openai-api",
-        action="store_true",
-        help="run fastchat's controller/openai_api servers",
-        dest="openai_api",
-    )
-    parser.add_argument(
-        "-m",
-        "--model-worker",
-        action="store_true",
-        help="run fastchat's model_worker server with specified model name. "
-             "specify --model-name if not using default llm models",
-        dest="model_worker",
-    )
-    parser.add_argument(
-        "-n",
-        "--model-name",
-        type=str,
-        nargs="+",
-        default=all_model_names_list,
-        help="specify model name for model worker. "
-             "add addition names with space seperated to start multiple model workers.",
-        dest="model_name",
-    )
-    parser.add_argument(
-        "-c",
-        "--controller",
-        type=str,
-        help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
-        dest="controller_address",
-    )
     parser.add_argument(
         "--api",
         action="store_true",
         help="run api.py server",
         dest="api",
     )
-    parser.add_argument(
-        "-p",
-        "--api-worker",
-        action="store_true",
-        help="run online model api such as zhipuai",
-        dest="api_worker",
-    )
     parser.add_argument(
         "-w",
         "--webui",
@@ -546,13 +133,6 @@ def parse_args() -> argparse.ArgumentParser:
         help="run webui.py server",
         dest="webui",
     )
-    parser.add_argument(
-        "-q",
-        "--quiet",
-        action="store_true",
-        help="减少fastchat服务log信息",
-        dest="quiet",
-    )
     parser.add_argument(
         "-i",
         "--lite",
@@ -578,24 +158,13 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")

-    models = list(LLM_MODEL_CONFIG['llm_model'].keys())
-    if args and args.model_name:
-        models = args.model_name
-
     print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
-    print(f"当前启动的LLM模型:{models} @ {llm_device()}")
-
-    for model in models:
-        pprint(get_model_worker_config(model))
     print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")

     if after_start:
         print("\n")
         print(f"服务端运行信息:")
-        if args.openai_api:
-            print(f"    OpenAI API Server: {fschat_openai_api_address()}")
-        if args.api:
-            print(f"    Chatchat API Server: {api_address()}")
         if args.webui:
             print(f"    Chatchat WEBUI Server: {webui_address()}")
     print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
@@ -625,7 +194,6 @@ async def start_main_server():
     manager = mp.Manager()
     run_mode = None

-    queue = manager.Queue()
     args, parser = parse_args()

     if args.all_webui:
@@ -662,69 +230,16 @@ async def start_main_server():
     processes = {"online_api": {}, "model_worker": {}}

     def process_count():
-        return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
+        return len(processes)

-    if args.quiet or not log_verbose:
-        log_level = "ERROR"
-    else:
-        log_level = "INFO"
-
-    controller_started = manager.Event()
-    if args.openai_api:
-        process = Process(
-            target=run_controller,
-            name=f"controller",
-            kwargs=dict(log_level=log_level, started_event=controller_started),
-            daemon=True,
-        )
-        processes["controller"] = process
-
-        process = Process(
-            target=run_openai_api,
-            name=f"openai_api",
-            daemon=True,
-        )
-        processes["openai_api"] = process
-
-    model_worker_started = []
-    if args.model_worker:
-        for model_name in args.model_name:
-            config = get_model_worker_config(model_name)
-            if not config.get("online_api"):
-                e = manager.Event()
-                model_worker_started.append(e)
-                process = Process(
-                    target=run_model_worker,
-                    name=f"model_worker - {model_name}",
-                    kwargs=dict(model_name=model_name,
-                                controller_address=args.controller_address,
-                                log_level=log_level,
-                                q=queue,
-                                started_event=e),
-                    daemon=True,
-                )
-                processes["model_worker"][model_name] = process
-
-    if args.api_worker:
-        for model_name in args.model_name:
-            config = get_model_worker_config(model_name)
-            if (config.get("online_api")
-                    and config.get("worker_class")
-                    and model_name in FSCHAT_MODEL_WORKERS):
-                e = manager.Event()
-                model_worker_started.append(e)
-                process = Process(
-                    target=run_model_worker,
-                    name=f"api_worker - {model_name}",
-                    kwargs=dict(model_name=model_name,
-                                controller_address=args.controller_address,
-                                log_level=log_level,
-                                q=queue,
-                                started_event=e),
-                    daemon=True,
-                )
-                processes["online_api"][model_name] = process
-
+    loom_started = manager.Event()
+    process = Process(
+        target=run_loom,
+        name=f"run_loom Server",
+        kwargs=dict(started_event=loom_started),
+        daemon=True,
+    )
+    processes["run_loom"] = process
     api_started = manager.Event()
     if args.api:
         process = Process(
@@ -750,25 +265,10 @@ async def start_main_server():
     else:
         try:
             # 保证任务收到SIGINT后,能够正常退出
-            if p := processes.get("controller"):
+            if p := processes.get("run_loom"):
                 p.start()
                 p.name = f"{p.name} ({p.pid})"
-                controller_started.wait()  # 等待controller启动完成
-
-            if p := processes.get("openai_api"):
-                p.start()
-                p.name = f"{p.name} ({p.pid})"
-
-            for n, p in processes.get("model_worker", {}).items():
-                p.start()
-                p.name = f"{p.name} ({p.pid})"
-
-            for n, p in processes.get("online_api", []).items():
-                p.start()
-                p.name = f"{p.name} ({p.pid})"
-
-            for e in model_worker_started:
-                e.wait()
+                loom_started.wait()  # 等待Loom启动完成

             if p := processes.get("api"):
                 p.start()
@@ -782,74 +282,10 @@ async def start_main_server():

             dump_server_info(after_start=True, args=args)

-            while True:
-                cmd = queue.get()  # 收到切换模型的消息
-                e = manager.Event()
-                if isinstance(cmd, list):
-                    model_name, cmd, new_model_name = cmd
-                    if cmd == "start":  # 运行新模型
-                        logger.info(f"准备启动新模型进程:{new_model_name}")
-                        process = Process(
-                            target=run_model_worker,
-                            name=f"model_worker - {new_model_name}",
-                            kwargs=dict(model_name=new_model_name,
-                                        controller_address=args.controller_address,
-                                        log_level=log_level,
-                                        q=queue,
-                                        started_event=e),
-                            daemon=True,
-                        )
-                        process.start()
-                        process.name = f"{process.name} ({process.pid})"
-                        processes["model_worker"][new_model_name] = process
-                        e.wait()
-                        logger.info(f"成功启动新模型进程:{new_model_name}")
-                    elif cmd == "stop":
-                        if process := processes["model_worker"].get(model_name):
-                            time.sleep(1)
-                            process.terminate()
-                            process.join()
-                            logger.info(f"停止模型进程:{model_name}")
-                        else:
-                            logger.error(f"未找到模型进程:{model_name}")
-                    elif cmd == "replace":
-                        if process := processes["model_worker"].pop(model_name, None):
-                            logger.info(f"停止模型进程:{model_name}")
-                            start_time = datetime.now()
-                            time.sleep(1)
-                            process.terminate()
-                            process.join()
-                            process = Process(
-                                target=run_model_worker,
-                                name=f"model_worker - {new_model_name}",
-                                kwargs=dict(model_name=new_model_name,
-                                            controller_address=args.controller_address,
-                                            log_level=log_level,
-                                            q=queue,
-                                            started_event=e),
-                                daemon=True,
-                            )
-                            process.start()
-                            process.name = f"{process.name} ({process.pid})"
-                            processes["model_worker"][new_model_name] = process
-                            e.wait()
-                            timing = datetime.now() - start_time
-                            logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。")
-                        else:
-                            logger.error(f"未找到模型进程:{model_name}")
-
-            # for process in processes.get("model_worker", {}).values():
-            #     process.join()
-            # for process in processes.get("online_api", {}).values():
-            #     process.join()
-
-            # for name, process in processes.items():
-            #     if name not in ["model_worker", "online_api"]:
-            #         if isinstance(p, dict):
-            #             for work_process in p.values():
-            #                 work_process.join()
-            #         else:
-            #             process.join()
+            # 等待所有进程退出
+            if p := processes.get("webui"):
+                p.join()
         except Exception as e:
             logger.error(e)
             logger.warning("Caught KeyboardInterrupt! Setting stop event...")
@@ -13,21 +13,6 @@ from typing import List
 api = ApiRequest()


-def test_get_default_llm():
-    llm = api.get_default_llm_model()
-
-    print(llm)
-    assert isinstance(llm, tuple)
-    assert isinstance(llm[0], str) and isinstance(llm[1], bool)
-
-
-def test_server_configs():
-    configs = api.get_server_configs()
-    pprint(configs, depth=2)
-
-    assert isinstance(configs, dict)
-    assert len(configs) > 0
-
-
 @pytest.mark.parametrize("type", ["llm_chat"])
 def test_get_prompt_template(type):
@@ -8,13 +8,13 @@ from server.model_workers.base import *
 from server.utils import get_model_worker_config, list_config_llm_models
 from pprint import pprint
 import pytest
-
-
-workers = []
-for x in list_config_llm_models()["online"]:
-    if x in ONLINE_LLM_MODEL and x not in workers:
-        workers.append(x)
-print(f"all workers to test: {workers}")
+#
+#
+# workers = []
+# for x in list_config_llm_models()["online"]:
+#     if x in ONLINE_LLM_MODEL and x not in workers:
+#         workers.append(x)
+# print(f"all workers to test: {workers}")

 # workers = ["fangzhou-api"]

webui.py: 25 changed lines

@@ -1,4 +1,6 @@
 import streamlit as st
+
+from webui_pages.openai_plugins import openai_plugins_page
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
 from webui_pages.dialogue.dialogue import dialogue_page, chat_box
@@ -22,9 +24,26 @@ if __name__ == "__main__":
             'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
             'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues",
             'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!"""
-        }
+        },
+        layout="wide"
+
     )

+    # use the following code to set the app to wide mode and the html markdown to increase the sidebar width
+    st.markdown(
+        """
+        <style>
+        [data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
+            width: 350px;
+        }
+        [data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
+            width: 600px;
+            margin-left: -600px;
+        }
+
+        """,
+        unsafe_allow_html=True,
+    )
     pages = {
         "对话": {
             "icon": "chat",
@@ -34,6 +53,10 @@ if __name__ == "__main__":
             "icon": "hdd-stack",
             "func": knowledge_base_page,
         },
+        "模型服务": {
+            "icon": "hdd-stack",
+            "func": openai_plugins_page,
+        },
     }

     with st.sidebar:
@ -1,8 +1,11 @@
 import base64

 import streamlit as st
+from streamlit_antd_components.utils import ParseItems

 from webui_pages.dialogue.utils import process_files
+from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
+    get_select_model_endpoint
 from webui_pages.utils import *
 from streamlit_chatbox import *
 from streamlit_modal import Modal

@ -10,12 +13,15 @@ from datetime import datetime
 import os
 import re
 import time
-from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
+from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY)
 from server.callback_handler.agent_callback_handler import AgentStatus
 from server.utils import MsgType
 import uuid
 from typing import List, Dict

+import streamlit_antd_components as sac
+
+
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",

@ -105,14 +111,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     st.session_state.setdefault("conversation_ids", {})
     st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
     st.session_state.setdefault("file_chat_id", None)
-    default_model = api.get_default_llm_model()[0]
-    if not chat_box.chat_inited:
-        st.toast(
-            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
-            f"当前运行的模型`{default_model}`, 您可以开始提问了."
-        )
-        chat_box.init_session()
+    st.session_state.setdefault("select_plugins_info", None)
+    st.session_state.setdefault("select_model_worker", None)

     # 弹出自定义命令帮助信息
     modal = Modal("自定义命令", key="cmd_help", max_width="500")

@ -131,57 +131,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     chat_box.use_chat_name(conversation_name)
     conversation_id = st.session_state["conversation_ids"][conversation_name]

-    # def on_mode_change():
-    #     mode = st.session_state.dialogue_mode
-    #     text = f"已切换到 {mode} 模式。"
-    #     st.toast(text)
-
-    # dialogue_modes = ["智能对话",
-    #                   "文件对话",
-    #                   ]
-    # dialogue_mode = st.selectbox("请选择对话模式:",
-    #                              dialogue_modes,
-    #                              index=0,
-    #                              on_change=on_mode_change,
-    #                              key="dialogue_mode",
-    #                              )
-
-    def on_llm_change():
-        if llm_model:
-            config = api.get_model_config(llm_model)
-            if not config.get("online_api"):  # 只有本地model_worker可以切换模型
-                st.session_state["prev_llm_model"] = llm_model
-            st.session_state["cur_llm_model"] = st.session_state.llm_model
-
-    def llm_model_format_func(x):
-        if x in running_models:
-            return f"{x} (Running)"
-        return x
-
-    running_models = list(api.list_running_models())
-    available_models = []
-    config_models = api.list_config_models()
-    if not is_lite:
-        for k, v in config_models.get("local", {}).items():  # 列出配置了有效本地路径的模型
-            if (v.get("model_path_exists")
-                    and k not in running_models):
-                available_models.append(k)
-    for k, v in config_models.get("online", {}).items():
-        if not v.get("provider") and k not in running_models and k in LLM_MODELS:
-            available_models.append(k)
-    llm_models = running_models + available_models
-    cur_llm_model = st.session_state.get("cur_llm_model", default_model)
-    if cur_llm_model in llm_models:
-        index = llm_models.index(cur_llm_model)
-    else:
-        index = 0
-    llm_model = st.selectbox("选择LLM模型",
-                             llm_models,
-                             index,
-                             format_func=llm_model_format_func,
-                             on_change=on_llm_change,
-                             key="llm_model",
-                             )
+    with st.expander("模型选择"):
+        plugins_menu = build_plugins_name()
+
+        items, _ = ParseItems(plugins_menu).multi()
+
+        if len(plugins_menu) > 0:
+            llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
+            plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
+            set_llm_select(plugins_info, llm_model_worker)
+        else:
+            st.info("没有可用的插件")

     # 传入后端的内容
     model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}

@ -213,8 +174,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         if is_selected:
             selected_tool_configs[tool] = TOOL_CONFIG[tool]

+    llm_model = None
+    if st.session_state["select_model_worker"] is not None:
+        llm_model = st.session_state["select_model_worker"]['label']
+
     if llm_model is not None:
-        model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
+        model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})

     uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
     files_upload = process_files(files=[uploaded_file]) if uploaded_file else None

@ -258,7 +223,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.rerun()
         else:
             history = get_messages_history(
-                model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
+                model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1)
+            )
             chat_box.user_say(prompt)
             if files_upload:
                 if files_upload["images"]:

@ -277,10 +243,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = ""
             text_action = ""
             element_index = 0
+
+            openai_config = {}
+            endpoint_host, select_model_name = get_select_model_endpoint()
+            openai_config["endpoint_host"] = endpoint_host
+            openai_config["model_name"] = select_model_name
+            openai_config["endpoint_host_key"] = OPENAI_KEY
+            openai_config["endpoint_host_proxy"] = OPENAI_PROXY
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
                                    history=history,
                                    model_config=model_config,
+                                   openai_config=openai_config,
                                    conversation_id=conversation_id,
                                    tool_config=selected_tool_configs,
                                    ):
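The `openai_config` dict assembled in the last hunk carries the endpoint chosen on the 模型服务 page plus the `OPENAI_KEY`/`OPENAI_PROXY` settings from configs. How the backend consumes it is not shown here; a minimal sketch, assuming the fields are mapped onto a `langchain_openai.ChatOpenAI` client (the helper name `build_chat_model` is hypothetical, not part of this commit):

    from langchain_openai import ChatOpenAI

    def build_chat_model(openai_config: dict) -> ChatOpenAI:
        # Hypothetical helper: maps the four fields filled in dialogue_page()
        # onto an OpenAI-compatible client. Assumed, not taken from the commit.
        return ChatOpenAI(
            model=openai_config["model_name"],
            base_url=openai_config["endpoint_host"],              # plugin endpoint
            api_key=openai_config["endpoint_host_key"],           # OPENAI_KEY
            openai_proxy=openai_config["endpoint_host_proxy"],    # OPENAI_PROXY (may be None)
        )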
@ -3,6 +3,7 @@ import base64
 import os
 from io import BytesIO

+
 def encode_file_to_base64(file):
     # 将文件内容转换为 Base64 编码
     buffer = BytesIO()
138  webui_pages/loom_view_client.py  (new file)

@ -0,0 +1,138 @@
+from typing import Tuple, Any
+
+import streamlit as st
+from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
+
+client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
+
+
+def update_store():
+    st.session_state.status = client.status()
+    list_plugins = client.list_plugins()
+    st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
+    launch_subscribe_info = {}
+    for plugin_name in st.session_state.run_plugins_list:
+        launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)
+
+    st.session_state.launch_subscribe_info = launch_subscribe_info
+    list_running_models = {}
+    for plugin_name in st.session_state.run_plugins_list:
+        list_running_models[plugin_name] = client.list_running_models(plugin_name)
+    st.session_state.list_running_models = list_running_models
+
+    model_config = {}
+    for plugin_name in st.session_state.run_plugins_list:
+        model_config[plugin_name] = client.list_llm_models(plugin_name)
+    st.session_state.model_config = model_config
+
+
+def start_plugin():
+    import time
+    start_plugins_name = st.session_state.plugins_name
+    if start_plugins_name in st.session_state.run_plugins_list:
+        st.toast(start_plugins_name + " has already been counted.")
+        time.sleep(.5)
+    else:
+        st.toast("start_plugin " + start_plugins_name + ",starting.")
+        result = client.launch_subscribe(start_plugins_name)
+        st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
+        time.sleep(3)
+        result1 = client.launch_subscribe_start(start_plugins_name)
+        st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
+        time.sleep(2)
+        update_store()
+
+
+def start_worker():
+    select_plugins_name = st.session_state.plugins_name
+    select_worker_id = st.session_state.worker_id
+    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
+    already_counted = False
+    for model in start_model_list:
+        if model['worker_id'] == select_worker_id:
+            already_counted = True
+            break
+
+    if already_counted:
+        st.toast(
+            "select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
+        import time
+        time.sleep(.5)
+    else:
+        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
+        result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
+        st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
+        import time
+        time.sleep(.5)
+        update_store()
+
+
+def stop_worker():
+    select_plugins_name = st.session_state.plugins_name
+    select_worker_id = st.session_state.worker_id
+    start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
+    already_counted = False
+    for model in start_model_list:
+        if model['worker_id'] == select_worker_id:
+            already_counted = True
+            break
+
+    if not already_counted:
+        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
+        import time
+        time.sleep(.5)
+    else:
+        st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
+        result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
+        st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
+        import time
+        time.sleep(.5)
+        update_store()
+
+
+def build_plugins_name():
+    import streamlit_antd_components as sac
+    if "run_plugins_list" not in st.session_state:
+        return []
+    # 按照模型构建sac.menu(菜单
+    menu_items = []
+    for key, value in st.session_state.list_running_models.items():
+        menu_item_children = []
+        for model in value:
+            menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))

+        menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
+
+    return menu_items
+
+
+def find_menu_items_by_index(menu_items, key):
+    for menu_item in menu_items:
+        if menu_item.get('children') is not None:
+            for child in menu_item.get('children'):
+                if child.get('key') == key:
+                    return menu_item, child
+
+    return None, None
+
+
+def set_llm_select(plugins_info, llm_model_worker):
+    st.session_state["select_plugins_info"] = plugins_info
+    st.session_state["select_model_worker"] = llm_model_worker
+
+
+def get_select_model_endpoint() -> Tuple[str, str]:
+    plugins_info = st.session_state["select_plugins_info"]
+    llm_model_worker = st.session_state["select_model_worker"]
+    if plugins_info is None or llm_model_worker is None:
+        raise ValueError("plugins_info or llm_model_worker is None")
+    plugins_name = st.session_state["select_plugins_info"]['label']
+    select_model_name = st.session_state["select_model_worker"]['label']
+    adapter_description = st.session_state.launch_subscribe_info[plugins_name]
+    endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
+    return endpoint_host, select_model_name
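Taken together with the dialogue.py hunks above, these helpers are driven roughly as follows (condensed from the pages in this commit; it assumes the Loom plugins daemon is reachable at the hard-coded http://localhost:8000):

    import streamlit_antd_components as sac
    from streamlit_antd_components.utils import ParseItems
    from webui_pages.loom_view_client import (
        update_store, build_plugins_name, find_menu_items_by_index,
        set_llm_select, get_select_model_endpoint,
    )

    update_store()                       # fill st.session_state from the plugins daemon
    plugins_menu = build_plugins_name()  # one sac.MenuItem per plugin, children per running model
    if plugins_menu:
        items, _ = ParseItems(plugins_menu).multi()
        index = sac.menu(plugins_menu, index=1, return_index=True)
        plugins_info, worker = find_menu_items_by_index(items, index)
        set_llm_select(plugins_info, worker)
        endpoint_host, model_name = get_select_model_endpoint()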
1  webui_pages/openai_plugins/__init__.py  (new file)

@ -0,0 +1 @@
+from .base import openai_plugins_page
75  webui_pages/openai_plugins/base.py  (new file)

@ -0,0 +1,75 @@
+import streamlit as st
+from loom_openai_plugins_frontend import loom_openai_plugins_frontend
+
+from webui_pages.utils import ApiRequest
+from webui_pages.loom_view_client import (
+    update_store,
+    start_plugin,
+    start_worker,
+    stop_worker,
+)
+
+
+def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
+    if "status" not in st.session_state \
+            or "run_plugins_list" not in st.session_state \
+            or "launch_subscribe_info" not in st.session_state \
+            or "list_running_models" not in st.session_state \
+            or "model_config" not in st.session_state:
+        update_store()
+
+    with (st.container()):
+
+        if "worker_id" not in st.session_state:
+            st.session_state.worker_id = ''
+        if "plugins_name" not in st.session_state and "status" in st.session_state:
+
+            for k, v in st.session_state.status.get("status", {}).get("subscribe_status", []).items():
+                st.session_state.plugins_name = v.get("plugins_names", [])[0]
+                break
+
+        col1, col2 = st.columns([0.8, 0.2])
+
+        with col1:
+            event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
+                                                 run_list_plugins=st.session_state.run_plugins_list,
+                                                 launch_subscribe_info=st.session_state.launch_subscribe_info,
+                                                 list_running_models=st.session_state.list_running_models,
+                                                 model_config=st.session_state.model_config)
+
+        with col2:
+            st.write("操作")
+            if not st.session_state.run_plugins_list:
+                button_type_disabled = False
+                button_start_text = '启动'
+            else:
+                button_type_disabled = True
+                button_start_text = '已启动'
+
+            if event:
+                event_type = event.get("event")
+                if event_type == "BottomNavigationAction":
+                    st.session_state.plugins_name = event.get("data")
+                    st.session_state.worker_id = ''
+                    # 不存在run_plugins_list,打开启动按钮
+                    if st.session_state.plugins_name not in st.session_state.run_plugins_list \
+                            or st.session_state.run_plugins_list:
+                        button_type_disabled = False
+                        button_start_text = '启动'
+                    else:
+                        button_type_disabled = True
+                        button_start_text = '已启动'
+                if event_type == "CardCoverComponent":
+                    st.session_state.worker_id = event.get("data")
+
+            st.button(button_start_text, disabled=button_type_disabled, key="start",
+                      on_click=start_plugin)
+
+            if st.session_state.worker_id != '':
+                st.button("启动" + st.session_state.worker_id, key="start_worker",
+                          on_click=start_worker)
+                st.button("停止" + st.session_state.worker_id, key="stop_worker",
+                          on_click=stop_worker)
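The loom_openai_plugins_frontend component call above returns an `event` dict that the page dispatches on; from the two handlers, the expected shapes look like this (the `data` values are placeholders):

    # Shapes inferred from the handlers in openai_plugins_page(); values are placeholders.
    plugin_selected = {"event": "BottomNavigationAction", "data": "some-plugin-name"}
    worker_selected = {"event": "CardCoverComponent", "data": "some-worker-id"}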
@ -266,6 +266,7 @@ class ApiRequest:
             history: List[Dict] = [],
             stream: bool = True,
             model_config: Dict = None,
+            openai_config: Dict = None,
             tool_config: Dict = None,
             **kwargs,
     ):

@ -280,6 +281,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_config": model_config,
+            "openai_config": openai_config,
             "tool_config": tool_config,
         }

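With the new keyword in place, a caller forwards the selected endpoint straight into the chat request payload. A minimal sketch of such a call (addresses and the model name are placeholders, not defaults of this commit):

    api = ApiRequest(base_url="http://127.0.0.1:7861")       # placeholder API server address
    openai_config = {
        "endpoint_host": "http://localhost:20000/v1",        # endpoint reported by the selected plugin
        "model_name": "my-model",                            # model served by the selected worker
        "endpoint_host_key": "EMPTY",
        "endpoint_host_proxy": None,
    }
    for chunk in api.chat_chat(query="hello",
                               history=[],
                               model_config={"llm_model": {}},
                               openai_config=openai_config,
                               tool_config={},
                               ):
        print(chunk)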
@ -619,219 +621,6 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)

-    # LLM模型相关操作
-    def list_running_models(
-            self,
-            controller_address: str = None,
-    ):
-        '''
-        获取Fastchat中正运行的模型列表
-        '''
-        data = {
-            "controller_address": controller_address,
-        }
-
-        if log_verbose:
-            logger.info(f'{self.__class__.__name__}:data: {data}')
-
-        response = self.post(
-            "/llm_model/list_running_models",
-            json=data,
-        )
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
-
-    def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
-        '''
-        从服务器上获取当前运行的LLM模型。
-        当 local_first=True 时,优先返回运行中的本地模型,否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回。
-        返回类型为(model_name, is_local_model)
-        '''
-
-        def ret_sync():
-            running_models = self.list_running_models()
-            if not running_models:
-                return "", False
-
-            model = ""
-            for m in LLM_MODEL_CONFIG['llm_model']:
-                if m not in running_models:
-                    continue
-                is_local = not running_models[m].get("online_api")
-                if local_first and not is_local:
-                    continue
-                else:
-                    model = m
-                    break
-
-            if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
-                model = list(running_models)[0]
-            is_local = not running_models[model].get("online_api")
-            return model, is_local
-
-        async def ret_async():
-            running_models = await self.list_running_models()
-            if not running_models:
-                return "", False
-
-            model = ""
-            for m in LLM_MODEL_CONFIG['llm_model']:
-                if m not in running_models:
-                    continue
-                is_local = not running_models[m].get("online_api")
-                if local_first and not is_local:
-                    continue
-                else:
-                    model = m
-                    break
-
-            if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
-                model = list(running_models)[0]
-            is_local = not running_models[model].get("online_api")
-            return model, is_local
-
-        if self._use_async:
-            return ret_async()
-        else:
-            return ret_sync()
-
-    def list_config_models(
-            self,
-            types: List[str] = ["local", "online"],
-    ) -> Dict[str, Dict]:
-        '''
-        获取服务器configs中配置的模型列表,返回形式为{"type": {model_name: config}, ...}。
-        '''
-        data = {
-            "types": types,
-        }
-        response = self.post(
-            "/llm_model/list_config_models",
-            json=data,
-        )
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
-
-    def get_model_config(
-            self,
-            model_name: str = None,
-    ) -> Dict:
-        '''
-        获取服务器上模型配置
-        '''
-        data = {
-            "model_name": model_name,
-        }
-        response = self.post(
-            "/llm_model/get_model_config",
-            json=data,
-        )
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
-
-    def stop_llm_model(
-            self,
-            model_name: str,
-            controller_address: str = None,
-    ):
-        '''
-        停止某个LLM模型。
-        注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
-        '''
-        data = {
-            "model_name": model_name,
-            "controller_address": controller_address,
-        }
-
-        response = self.post(
-            "/llm_model/stop",
-            json=data,
-        )
-        return self._get_response_value(response, as_json=True)
-
-    def change_llm_model(
-            self,
-            model_name: str,
-            new_model_name: str,
-            controller_address: str = None,
-    ):
-        '''
-        向fastchat controller请求切换LLM模型。
-        '''
-        if not model_name or not new_model_name:
-            return {
-                "code": 500,
-                "msg": f"未指定模型名称"
-            }
-
-        def ret_sync():
-            running_models = self.list_running_models()
-            if new_model_name == model_name or new_model_name in running_models:
-                return {
-                    "code": 200,
-                    "msg": "无需切换"
-                }
-
-            if model_name not in running_models:
-                return {
-                    "code": 500,
-                    "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
-                }
-
-            config_models = self.list_config_models()
-            if new_model_name not in config_models.get("local", {}):
-                return {
-                    "code": 500,
-                    "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
-                }
-
-            data = {
-                "model_name": model_name,
-                "new_model_name": new_model_name,
-                "controller_address": controller_address,
-            }
-
-            response = self.post(
-                "/llm_model/change",
-                json=data,
-            )
-            return self._get_response_value(response, as_json=True)
-
-        async def ret_async():
-            running_models = await self.list_running_models()
-            if new_model_name == model_name or new_model_name in running_models:
-                return {
-                    "code": 200,
-                    "msg": "无需切换"
-                }
-
-            if model_name not in running_models:
-                return {
-                    "code": 500,
-                    "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
-                }
-
-            config_models = await self.list_config_models()
-            if new_model_name not in config_models.get("local", {}):
-                return {
-                    "code": 500,
-                    "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
-                }
-
-            data = {
-                "model_name": model_name,
-                "new_model_name": new_model_name,
-                "controller_address": controller_address,
-            }
-
-            response = self.post(
-                "/llm_model/change",
-                json=data,
-            )
-            return self._get_response_value(response, as_json=True)
-
-        if self._use_async:
-            return ret_async()
-        else:
-            return ret_sync()
-
     def embed_texts(
             self,
             texts: List[str],