Integrate openai plugins

glide-the 2024-01-18 01:11:58 +08:00 committed by liunux4odoo
parent 8063aab7a1
commit 48fb6b83fd
16 changed files with 353 additions and 1130 deletions

View File

@ -8,7 +8,7 @@ torchaudio>=2.1.2
# Langchain 0.1.x requirements
langchain>=0.1.0
langchain_openai>=0.0.2
langchain-community>=1.0.0
langchain-community>=0.0.11
langchainhub>=0.1.14
pydantic==1.10.13

View File

@ -17,9 +17,7 @@ from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
get_model_config)
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)
from typing import List, Literal
@ -73,32 +71,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
# 摘要相关接口
mount_filename_summary_routes(app)
# LLM模型相关接口
app.post("/llm_model/list_running_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_running_models)
app.post("/llm_model/list_config_models",
tags=["LLM Model Management"],
summary="列出configs已配置的模型",
)(list_config_models)
app.post("/llm_model/get_model_config",
tags=["LLM Model Management"],
summary="获取模型配置(合并后)",
)(get_model_config)
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)(stop_llm_model)
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)(change_llm_model)
# 服务器相关接口
app.post("/server/configs",
tags=["Server State"],

View File

@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
def create_models_from_config(configs, callbacks, stream):
def create_models_from_config(configs, openai_config, callbacks, stream):
if configs is None:
configs = {}
models = {}
@ -30,6 +30,9 @@ def create_models_from_config(configs, callbacks, stream):
for model_name, params in model_configs.items():
callbacks = callbacks if params.get('callbacks', False) else None
model_instance = get_ChatOpenAI(
endpoint_host=openai_config.get('endpoint_host', None),
endpoint_host_key=openai_config.get('endpoint_host_key', None),
endpoint_host_proxy=openai_config.get('endpoint_host_proxy', None),
model_name=model_name,
temperature=params.get('temperature', 0.5),
max_tokens=params.get('max_tokens', 1000),
@ -113,6 +116,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
),
stream: bool = Body(True, description="流式输出"),
model_config: Dict = Body({}, description="LLM 模型配置"),
openai_config: Dict = Body({}, description="openaiEndpoint配置"),
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
@ -124,7 +128,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config,
openai_config=openai_config, stream=stream)
tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
full_chain = create_models_chains(prompts=prompts,
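
For reference, a minimal client-side sketch of how the new openai_config body can be sent to the chat route after this change. The server address, route path, and all endpoint/model values below are illustrative assumptions, not part of this commit:

import requests

payload = {
    "query": "你好",
    "stream": False,
    "model_config": {"llm_model": {"gpt-3.5-turbo": {"temperature": 0.5, "max_tokens": 1000}}},
    # New in this commit: per-request OpenAI-compatible endpoint settings
    "openai_config": {
        "endpoint_host": "https://api.openai.com/v1",  # hypothetical endpoint
        "endpoint_host_key": "sk-...",                  # hypothetical key
        "endpoint_host_proxy": None,
    },
    "tool_config": {},
}
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)  # assumed host/port and route
print(resp.text)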

View File

@ -14,6 +14,9 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
stream: bool = Body(False, description="流式输出"),
echo: bool = Body(False, description="除了输出之外,还回显输入"),
endpoint_host: str = Body(None, description="接入点地址"),
endpoint_host_key: str = Body(None, description="接入点key"),
endpoint_host_proxy: str = Body(None, description="接入点代理地址"),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
@ -24,6 +27,9 @@ async def completion(query: str = Body(..., description="用户输入", examples
#TODO: 因ApiModelWorker 默认是按chat处理的会对params["prompt"] 解析为messages因此ApiModelWorker 使用时需要有相应处理
async def completion_iterator(query: str,
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
model_name: str = None,
prompt_name: str = prompt_name,
echo: bool = echo,
@ -34,6 +40,9 @@ async def completion(query: str = Body(..., description="用户输入", examples
max_tokens = None
model = get_OpenAI(
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
@ -63,7 +72,10 @@ async def completion(query: str = Body(..., description="用户输入", examples
await task
return EventSourceResponse(completion_iterator(query=query,
return StreamingResponse(completion_iterator(query=query,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
model_name=model_name,
prompt_name=prompt_name),
)
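
As a quick reference, the completion endpoint's request body now carries the endpoint settings directly. A minimal sketch of such a body (field names follow the Body parameters above; the values are hypothetical):

completion_request = {
    "query": "恼羞成怒",
    "stream": False,
    "echo": False,
    # New fields in this commit: the OpenAI-compatible endpoint to call
    "endpoint_host": "https://api.openai.com/v1",  # hypothetical endpoint
    "endpoint_host_key": "sk-...",                  # hypothetical key
    "endpoint_host_proxy": None,
    "model_name": "gpt-3.5-turbo",                  # hypothetical model
    "temperature": 0.01,
    "max_tokens": 1024,
}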

View File

@ -1,124 +0,0 @@
from fastapi import Body
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, LLM_MODEL_CONFIG
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from typing import List
def list_running_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表及其配置项
'''
try:
controller_address = controller_address or fschat_controller_address()
with get_httpx_client() as client:
r = client.post(controller_address + "/list_models")
models = r.json()["models"]
data = {m: get_model_config(m).data for m in models}
## 只有LLM模型才返回
result = {}
for model, config in data.items():
if model in LLM_MODEL_CONFIG['llm_model']:
result[model] = config
return BaseResponse(data=result)
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data={},
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def list_config_models(
types: List[str] = Body(["local", "online"], description="模型配置项类别如local, online, worker"),
placeholder: str = Body(None, description="占位用,无实际效果")
) -> BaseResponse:
'''
从本地获取configs中配置的模型列表
'''
data = {}
result = {}
for type, models in list_config_llm_models().items():
if type in types:
data[type] = {m: get_model_config(m).data for m in models}
for model, config in data.items():
if model in LLM_MODEL_CONFIG['llm_model']:
result[type][model] = config
return BaseResponse(data=result)
def get_model_config(
model_name: str = Body(description="配置中LLM模型的名称"),
placeholder: str = Body(None, description="占位用,无实际效果")
) -> BaseResponse:
'''
获取LLM模型配置项合并后的
'''
config = {}
# 删除ONLINE_MODEL配置中的敏感信息
for k, v in get_model_worker_config(model_name=model_name).items():
if not (k == "worker_class"
or "key" in k.lower()
or "secret" in k.lower()
or k.lower().endswith("id")):
config[k] = v
return BaseResponse(data=config)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
def change_llm_model(
model_name: str = Body(..., description="当前运行模型"),
new_model_name: str = Body(..., description="要切换的新模型"),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")

View File

@ -45,6 +45,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
def get_ChatOpenAI(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
model_name: str,
temperature: float,
max_tokens: int = None,
@ -61,18 +64,21 @@ def get_ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
**kwargs
)
return model
def get_OpenAI(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
model_name: str,
temperature: float,
max_tokens: int = None,
@ -82,19 +88,18 @@ def get_OpenAI(
verbose: bool = True,
**kwargs: Any,
) -> OpenAI:
config = get_model_worker_config(model_name)
if model_name == "openai-api":
model_name = config.get("model_name")
# TODO: 从API获取模型信息
model = OpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
echo=echo,
**kwargs
)
@ -365,69 +370,14 @@ def get_model_worker_config(model_name: str = None) -> dict:
'''
from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
from configs.server_config import FSCHAT_MODEL_WORKERS
from server import model_workers
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
msg = f"在线模型 {model_name} 的provider没有正确配置"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
# 本地模型
if model_name in MODEL_PATH["llm_model"]:
path = get_model_path(model_name)
config["model_path"] = path
if path and os.path.isdir(path):
config["model_path_exists"] = True
config["device"] = llm_device(config.get("device"))
return config
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
return result
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER
host = FSCHAT_CONTROLLER["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
port = FSCHAT_CONTROLLER["port"]
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
port = model["port"]
return f"http://{host}:{port}"
return ""
def fschat_openai_api_address() -> str:
from configs.server_config import FSCHAT_OPENAI_API
host = FSCHAT_OPENAI_API["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}/v1"
def api_address() -> str:
from configs.server_config import API_SERVER
@ -461,6 +411,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
def set_httpx_config(
timeout: float = HTTPX_DEFAULT_TIMEOUT,
proxy: Union[str, Dict] = None,
unused_proxies: List[str] = [],
):
'''
设置httpx默认timeouthttpx默认timeout是5秒在请求LLM回答时不够用
@ -498,11 +449,7 @@ def set_httpx_config(
"http://localhost",
]
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
for x in unused_proxies:
host = ":".join(x.split(":")[:2])
if host not in no_proxy:
no_proxy.append(host)
@ -568,6 +515,7 @@ def get_httpx_client(
use_async: bool = False,
proxies: Union[str, Dict] = None,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
unused_proxies: List[str] = [],
**kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
'''
@ -579,11 +527,7 @@ def get_httpx_client(
"all://localhost": None,
}
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
for x in unused_proxies:
host = ":".join(x.split(":")[:2])
default_proxies.update({host: None})
@ -629,8 +573,6 @@ def get_server_configs() -> Dict:
获取configs中的原始配置项供前端使用
'''
_custom = {
"controller_address": fschat_controller_address(),
"openai_api_address": fschat_openai_api_address(),
"api_address": api_address(),
}
@ -638,14 +580,8 @@ def get_server_configs() -> Dict:
def list_online_embed_models() -> List[str]:
from server import model_workers
ret = []
for k, v in list_config_llm_models()["online"].items():
if provider := v.get("provider"):
worker_class = getattr(model_workers, provider, None)
if worker_class is not None and worker_class.can_embedding():
ret.append(k)
# TODO: 从在线API获取支持的模型列表
return ret
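
With the fastchat worker lookup removed from this path, get_ChatOpenAI and get_OpenAI are now driven entirely by the endpoint arguments. A minimal sketch of a direct call under that assumption (endpoint values and model name are hypothetical):

from server.utils import get_ChatOpenAI

model = get_ChatOpenAI(
    endpoint_host="https://api.openai.com/v1",  # hypothetical OpenAI-compatible endpoint
    endpoint_host_key="sk-...",                  # hypothetical API key
    endpoint_host_proxy=None,                    # no proxy in this sketch
    model_name="gpt-3.5-turbo",                  # hypothetical model name
    temperature=0.5,
    max_tokens=1000,
)
# print(model.invoke("你好"))  # would call the configured endpoint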

View File

@ -25,17 +25,11 @@ from configs import (
LLM_MODEL_CONFIG,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API,
FSCHAT_MODEL_WORKERS,
API_SERVER,
WEBUI_SERVER,
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, get_httpx_client,
get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.utils import (FastAPI, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
@ -49,234 +43,6 @@ for model_category in LLM_MODEL_CONFIG.values():
all_model_names_list = list(all_model_names)
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller, logger
logger.setLevel(log_level)
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
app._controller = controller
return app
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
"""
kwargs包含的字段如下
host:
port:
model_names:[`model_name`]
controller_address:
worker_address:
对于Langchain支持的模型
langchain_model:True
不会使用fschat
对于online_api:
online_api:True
worker_class: `provider`
对于离线模型
model_path: `model_name_or_path`,huggingface的repo-id或本地路径
device:`LLM_DEVICE`
"""
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_args([])
for k, v in kwargs.items():
setattr(args, k, v)
if worker_class := kwargs.get("langchain_model"): # Langchian支持的模型不用做操作
from fastchat.serve.base_model_worker import app
worker = ""
# 在线模型API
elif worker_class := kwargs.get("worker_class"):
from fastchat.serve.base_model_worker import app
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
# 本地模型
else:
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
args.tokenizer_mode = 'auto'
args.trust_remote_code = True
args.download_dir = None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
args.worker_use_ray = False
args.pipeline_parallel_size = 1
args.tensor_parallel_size = 1
args.block_size = 16
args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = None # 一个批次中的最大令牌tokens数量这个取决于你的显卡和大模型设置设置太大显存会不够
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 4 # vllm worker的切分是tensor并行这里填写显卡的数量
args.engine_use_ray = False
args.disable_log_requests = False
# 0.2.2 vllm后要加的参数, 但是这里不需要
args.max_model_len = None
args.revision = None
args.quantization = None
args.max_log_len = None
args.tokenizer_revision = None
# 0.2.2 vllm需要新加的参数
args.max_paddings = 256
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
args.tensor_parallel_size = args.num_gpus
for k, v in kwargs.items():
setattr(args, k, v)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
llm_engine=engine,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB"
args.num_gpus = 1 # model worker的切分是model并行这里填写显卡的数量
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.model_names = [""]
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
# sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
from fastchat.utils import build_logger
logger = build_logger("openai_api", "openai_api.log")
logger.setLevel(log_level)
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
sys.modules["fastchat.serve.openai_api_server"].logger = logger
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup")
@ -285,154 +51,6 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
started_event.set()
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import httpx
from fastapi import Body
import time
import sys
from server.utils import set_httpx_config
set_httpx_config()
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
log_level=log_level,
)
_set_app_event(app, started_event)
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
logger.info(msg)
return {"code": 500, "msg": msg}
if new_model_name:
logger.info(f"开始切换LLM模型{model_name}{new_model_name}")
else:
logger.info(f"即将停止LLM模型 {model_name}")
if model_name not in available_models:
msg = f"the model {model_name} is not available"
logger.error(msg)
return {"code": 500, "msg": msg}
worker_address = app._controller.get_worker_address(model_name)
if not worker_address:
msg = f"can not find model_worker address for {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
break
time.sleep(1)
timer -= 1
if timer > 0:
msg = f"sucess change model from {model_name} to {new_model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
else:
msg = f"failed change model from {model_name} to {new_model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
else:
msg = f"sucess to release model: {model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_model_worker(
model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,
started_event: mp.Event = None,
):
import uvicorn
from fastapi import Body
import sys
from server.utils import set_httpx_config
set_httpx_config()
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_event(app, started_event)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put([model_name, "start", new_model_name])
else:
if new_model_name:
q.put([model_name, "replace", new_model_name])
else:
q.put([model_name, "stop", None])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import sys
from server.utils import set_httpx_config
set_httpx_config()
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level)
_set_app_event(app, started_event)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
from server.api import create_app
import uvicorn
@ -473,6 +91,18 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
p.wait()
def run_loom(started_event: mp.Event = None):
from configs import LOOM_CONFIG
cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
"-f", LOOM_CONFIG
]
p = subprocess.Popen(cmd)
started_event.set()
p.wait()
def parse_args() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
@ -488,57 +118,14 @@ def parse_args() -> argparse.ArgumentParser:
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
dest="all_api",
)
parser.add_argument(
"--llm-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers",
dest="llm_api",
)
parser.add_argument(
"-o",
"--openai-api",
action="store_true",
help="run fastchat's controller/openai_api servers",
dest="openai_api",
)
parser.add_argument(
"-m",
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
"specify --model-name if not using default llm models",
dest="model_worker",
)
parser.add_argument(
"-n",
"--model-name",
type=str,
nargs="+",
default=all_model_names_list,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
)
parser.add_argument(
"-c",
"--controller",
type=str,
help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
dest="controller_address",
)
parser.add_argument(
"--api",
action="store_true",
help="run api.py server",
dest="api",
)
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument(
"-w",
"--webui",
@ -546,13 +133,6 @@ def parse_args() -> argparse.ArgumentParser:
help="run webui.py server",
dest="webui",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="减少fastchat服务log信息",
dest="quiet",
)
parser.add_argument(
"-i",
"--lite",
@ -578,24 +158,13 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
models = list(LLM_MODEL_CONFIG['llm_model'].keys())
if args and args.model_name:
models = args.model_name
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前启动的LLM模型{models} @ {llm_device()}")
for model in models:
pprint(get_model_worker_config(model))
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
@ -625,7 +194,6 @@ async def start_main_server():
manager = mp.Manager()
run_mode = None
queue = manager.Queue()
args, parser = parse_args()
if args.all_webui:
@ -662,69 +230,16 @@ async def start_main_server():
processes = {"online_api": {}, "model_worker": {}}
def process_count():
return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
if args.quiet or not log_verbose:
log_level = "ERROR"
else:
log_level = "INFO"
controller_started = manager.Event()
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller",
kwargs=dict(log_level=log_level, started_event=controller_started),
daemon=True,
)
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api",
daemon=True,
)
processes["openai_api"] = process
model_worker_started = []
if args.model_worker:
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if not config.get("online_api"):
e = manager.Event()
model_worker_started.append(e)
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name}",
kwargs=dict(model_name=model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
processes["model_worker"][model_name] = process
if args.api_worker:
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if (config.get("online_api")
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):
e = manager.Event()
model_worker_started.append(e)
process = Process(
target=run_model_worker,
name=f"api_worker - {model_name}",
kwargs=dict(model_name=model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
processes["online_api"][model_name] = process
return len(processes)
loom_started = manager.Event()
process = Process(
target=run_loom,
name=f"run_loom Server",
kwargs=dict(started_event=loom_started),
daemon=True,
)
processes["run_loom"] = process
api_started = manager.Event()
if args.api:
process = Process(
@ -750,25 +265,10 @@ async def start_main_server():
else:
try:
# 保证任务收到SIGINT后能够正常退出
if p := processes.get("controller"):
if p := processes.get("run_loom"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait() # 等待controller启动完成
if p := processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
for n, p in processes.get("model_worker", {}).items():
p.start()
p.name = f"{p.name} ({p.pid})"
for n, p in processes.get("online_api", []).items():
p.start()
p.name = f"{p.name} ({p.pid})"
for e in model_worker_started:
e.wait()
loom_started.wait() # 等待Loom启动完成
if p := processes.get("api"):
p.start()
@ -782,74 +282,10 @@ async def start_main_server():
dump_server_info(after_start=True, args=args)
while True:
cmd = queue.get() # 收到切换模型的消息
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
if cmd == "start": # 运行新模型
logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
logger.info(f"成功启动新模型进程:{new_model_name}")
elif cmd == "stop":
if process := processes["model_worker"].get(model_name):
time.sleep(1)
process.terminate()
process.join()
logger.info(f"停止模型进程:{model_name}")
else:
logger.error(f"未找到模型进程:{model_name}")
elif cmd == "replace":
if process := processes["model_worker"].pop(model_name, None):
logger.info(f"停止模型进程:{model_name}")
start_time = datetime.now()
time.sleep(1)
process.terminate()
process.join()
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
timing = datetime.now() - start_time
logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}")
else:
logger.error(f"未找到模型进程:{model_name}")
# 等待所有进程退出
if p := processes.get("webui"):
# for process in processes.get("model_worker", {}).values():
# process.join()
# for process in processes.get("online_api", {}).values():
# process.join()
# for name, process in processes.items():
# if name not in ["model_worker", "online_api"]:
# if isinstance(p, dict):
# for work_process in p.values():
# work_process.join()
# else:
# process.join()
p.join()
except Exception as e:
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")

View File

@ -13,21 +13,6 @@ from typing import List
api = ApiRequest()
def test_get_default_llm():
llm = api.get_default_llm_model()
print(llm)
assert isinstance(llm, tuple)
assert isinstance(llm[0], str) and isinstance(llm[1], bool)
def test_server_configs():
configs = api.get_server_configs()
pprint(configs, depth=2)
assert isinstance(configs, dict)
assert len(configs) > 0
@pytest.mark.parametrize("type", ["llm_chat"])
def test_get_prompt_template(type):

View File

@ -8,13 +8,13 @@ from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest
workers = []
for x in list_config_llm_models()["online"]:
if x in ONLINE_LLM_MODEL and x not in workers:
workers.append(x)
print(f"all workers to test: {workers}")
#
#
# workers = []
# for x in list_config_llm_models()["online"]:
# if x in ONLINE_LLM_MODEL and x not in workers:
# workers.append(x)
# print(f"all workers to test: {workers}")
# workers = ["fangzhou-api"]
@ -56,4 +56,4 @@ def test_embeddings(worker):
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
print("向量长度:", len(embeddings[0]))

View File

@ -1,4 +1,6 @@
import streamlit as st
from webui_pages.openai_plugins import openai_plugins_page
from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages.dialogue.dialogue import dialogue_page, chat_box
@ -22,9 +24,26 @@ if __name__ == "__main__":
'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
'Report a bug': "https://github.com/chatchat-space/Langchain-Chatchat/issues",
'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}"""
}
},
layout="wide"
)
# use the following code to set the app to wide mode and the html markdown to increase the sidebar width
st.markdown(
"""
<style>
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
width: 350px;
}
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
width: 600px;
margin-left: -600px;
}
</style>
""",
unsafe_allow_html=True,
)
pages = {
"对话": {
"icon": "chat",
@ -34,6 +53,10 @@ if __name__ == "__main__":
"icon": "hdd-stack",
"func": knowledge_base_page,
},
"模型服务": {
"icon": "hdd-stack",
"func": openai_plugins_page,
},
}
with st.sidebar:

View File

@ -1,8 +1,11 @@
import base64
import streamlit as st
from streamlit_antd_components.utils import ParseItems
from webui_pages.dialogue.utils import process_files
from webui_pages.loom_view_client import build_plugins_name, find_menu_items_by_index, set_llm_select, \
get_select_model_endpoint
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
@ -10,12 +13,15 @@ from datetime import datetime
import os
import re
import time
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY)
from server.callback_handler.agent_callback_handler import AgentStatus
from server.utils import MsgType
import uuid
from typing import List, Dict
import streamlit_antd_components as sac
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@ -105,14 +111,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
st.session_state.setdefault("file_chat_id", None)
default_model = api.get_default_llm_model()[0]
if not chat_box.chat_inited:
st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
f"当前运行的模型`{default_model}`, 您可以开始提问了."
)
chat_box.init_session()
st.session_state.setdefault("select_plugins_info", None)
st.session_state.setdefault("select_model_worker", None)
# 弹出自定义命令帮助信息
modal = Modal("自定义命令", key="cmd_help", max_width="500")
@ -131,57 +131,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
# def on_mode_change():
# mode = st.session_state.dialogue_mode
# text = f"已切换到 {mode} 模式。"
# st.toast(text)
with st.expander("模型选择"):
plugins_menu = build_plugins_name()
# dialogue_modes = ["智能对话",
# "文件对话",
# ]
# dialogue_mode = st.selectbox("请选择对话模式:",
# dialogue_modes,
# index=0,
# on_change=on_mode_change,
# key="dialogue_mode",
# )
items, _ = ParseItems(plugins_menu).multi()
def on_llm_change():
if llm_model:
config = api.get_model_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model
if len(plugins_menu) > 0:
def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x
running_models = list(api.list_running_models())
available_models = []
config_models = api.list_config_models()
if not is_lite:
for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型
if (v.get("model_path_exists")
and k not in running_models):
available_models.append(k)
for k, v in config_models.get("online", {}).items():
if not v.get("provider") and k not in running_models and k in LLM_MODELS:
available_models.append(k)
llm_models = running_models + available_models
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
if cur_llm_model in llm_models:
index = llm_models.index(cur_llm_model)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
key="llm_model",
)
llm_model_index = sac.menu(plugins_menu, index=1, return_index=True)
plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index)
set_llm_select(plugins_info, llm_model_worker)
else:
st.info("没有可用的插件")
# 传入后端的内容
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
@ -213,8 +174,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]
llm_model = None
if st.session_state["select_model_worker"] is not None:
llm_model = st.session_state["select_model_worker"]['label']
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@ -258,7 +223,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.rerun()
else:
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
model_config["llm_model"].get(next(iter(model_config["llm_model"])), {}).get("history_len", 1)
)
chat_box.user_say(prompt)
if files_upload:
if files_upload["images"]:
@ -277,10 +243,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text = ""
text_action = ""
element_index = 0
openai_config = {}
endpoint_host, select_model_name = get_select_model_endpoint()
openai_config["endpoint_host"] = endpoint_host
openai_config["model_name"] = select_model_name
openai_config["endpoint_host_key"] = OPENAI_KEY
openai_config["endpoint_host_proxy"] = OPENAI_PROXY
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
model_config=model_config,
openai_config=openai_config,
conversation_id=conversation_id,
tool_config=selected_tool_configs,
):

View File

@ -3,6 +3,7 @@ import base64
import os
from io import BytesIO
def encode_file_to_base64(file):
# 将文件内容转换为 Base64 编码
buffer = BytesIO()

View File

@ -0,0 +1,138 @@
from typing import Tuple, Any
import streamlit as st
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
def update_store():
st.session_state.status = client.status()
list_plugins = client.list_plugins()
st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
launch_subscribe_info = {}
for plugin_name in st.session_state.run_plugins_list:
launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)
st.session_state.launch_subscribe_info = launch_subscribe_info
list_running_models = {}
for plugin_name in st.session_state.run_plugins_list:
list_running_models[plugin_name] = client.list_running_models(plugin_name)
st.session_state.list_running_models = list_running_models
model_config = {}
for plugin_name in st.session_state.run_plugins_list:
model_config[plugin_name] = client.list_llm_models(plugin_name)
st.session_state.model_config = model_config
def start_plugin():
import time
start_plugins_name = st.session_state.plugins_name
if start_plugins_name in st.session_state.run_plugins_list:
st.toast(start_plugins_name + " has already been counted.")
time.sleep(.5)
else:
st.toast("start_plugin " + start_plugins_name + ",starting.")
result = client.launch_subscribe(start_plugins_name)
st.toast("start_plugin "+start_plugins_name + " ." + result.get("detail", ""))
time.sleep(3)
result1 = client.launch_subscribe_start(start_plugins_name)
st.toast("start_plugin "+start_plugins_name + " ." + result1.get("detail", ""))
time.sleep(2)
update_store()
def start_worker():
select_plugins_name = st.session_state.plugins_name
select_worker_id = st.session_state.worker_id
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
already_counted = False
for model in start_model_list:
if model['worker_id'] == select_worker_id:
already_counted = True
break
if already_counted:
st.toast(
"select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
import time
time.sleep(.5)
else:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
import time
time.sleep(.5)
update_store()
def stop_worker():
select_plugins_name = st.session_state.plugins_name
select_worker_id = st.session_state.worker_id
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
already_counted = False
for model in start_model_list:
if model['worker_id'] == select_worker_id:
already_counted = True
break
if not already_counted:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
import time
time.sleep(.5)
else:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
import time
time.sleep(.5)
update_store()
def build_plugins_name():
import streamlit_antd_components as sac
if "run_plugins_list" not in st.session_state:
return []
# 按照模型构建 sac.menu 菜单
menu_items = []
for key, value in st.session_state.list_running_models.items():
menu_item_children = []
for model in value:
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
return menu_items
def find_menu_items_by_index(menu_items, key):
for menu_item in menu_items:
if menu_item.get('children') is not None:
for child in menu_item.get('children'):
if child.get('key') == key:
return menu_item, child
return None, None
def set_llm_select(plugins_info, llm_model_worker):
st.session_state["select_plugins_info"] = plugins_info
st.session_state["select_model_worker"] = llm_model_worker
def get_select_model_endpoint() -> Tuple[str, str]:
plugins_info = st.session_state["select_plugins_info"]
llm_model_worker = st.session_state["select_model_worker"]
if plugins_info is None or llm_model_worker is None:
raise ValueError("plugins_info or llm_model_worker is None")
plugins_name = st.session_state["select_plugins_info"]['label']
select_model_name = st.session_state["select_model_worker"]['label']
adapter_description = st.session_state.launch_subscribe_info[plugins_name]
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
return endpoint_host, select_model_name
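
A minimal sketch of how these helpers fit together in a Streamlit page, mirroring their use in dialogue.py above (the menu index handling is illustrative):

import streamlit_antd_components as sac
from streamlit_antd_components.utils import ParseItems
from webui_pages.loom_view_client import (
    update_store, build_plugins_name, find_menu_items_by_index,
    set_llm_select, get_select_model_endpoint,
)

update_store()                       # pull plugin and model state from the loom client
plugins_menu = build_plugins_name()  # one menu group per plugin, its models as children
if plugins_menu:
    items, _ = ParseItems(plugins_menu).multi()
    index = sac.menu(plugins_menu, index=1, return_index=True)
    plugins_info, llm_model_worker = find_menu_items_by_index(items, index)
    set_llm_select(plugins_info, llm_model_worker)
    endpoint_host, model_name = get_select_model_endpoint()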

View File

@ -0,0 +1 @@
from .base import openai_plugins_page

View File

@ -0,0 +1,75 @@
import streamlit as st
from loom_openai_plugins_frontend import loom_openai_plugins_frontend
from webui_pages.utils import ApiRequest
from webui_pages.loom_view_client import (
update_store,
start_plugin,
start_worker,
stop_worker,
)
def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
if "status" not in st.session_state \
or "run_plugins_list" not in st.session_state \
or "launch_subscribe_info" not in st.session_state \
or "list_running_models" not in st.session_state \
or "model_config" not in st.session_state:
update_store()
with (st.container()):
if "worker_id" not in st.session_state:
st.session_state.worker_id = ''
if "plugins_name" not in st.session_state and "status" in st.session_state:
for k, v in st.session_state.status.get("status", {}).get("subscribe_status", {}).items():
st.session_state.plugins_name = v.get("plugins_names", [])[0]
break
col1, col2 = st.columns([0.8, 0.2])
with col1:
event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
run_list_plugins=st.session_state.run_plugins_list,
launch_subscribe_info=st.session_state.launch_subscribe_info,
list_running_models=st.session_state.list_running_models,
model_config=st.session_state.model_config)
with col2:
st.write("操作")
if not st.session_state.run_plugins_list:
button_type_disabled = False
button_start_text = '启动'
else:
button_type_disabled = True
button_start_text = '已启动'
if event:
event_type = event.get("event")
if event_type == "BottomNavigationAction":
st.session_state.plugins_name = event.get("data")
st.session_state.worker_id = ''
# 不存在run_plugins_list打开启动按钮
if st.session_state.plugins_name not in st.session_state.run_plugins_list \
or st.session_state.run_plugins_list:
button_type_disabled = False
button_start_text = '启动'
else:
button_type_disabled = True
button_start_text = '已启动'
if event_type == "CardCoverComponent":
st.session_state.worker_id = event.get("data")
st.button(button_start_text, disabled=button_type_disabled, key="start",
on_click=start_plugin)
if st.session_state.worker_id != '':
st.button("启动" + st.session_state.worker_id, key="start_worker",
on_click=start_worker)
st.button("停止" + st.session_state.worker_id, key="stop_worker",
on_click=stop_worker)

View File

@ -266,6 +266,7 @@ class ApiRequest:
history: List[Dict] = [],
stream: bool = True,
model_config: Dict = None,
openai_config: Dict = None,
tool_config: Dict = None,
**kwargs,
):
@ -280,6 +281,7 @@ class ApiRequest:
"history": history,
"stream": stream,
"model_config": model_config,
"openai_config": openai_config,
"tool_config": tool_config,
}
@ -619,219 +621,6 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
# LLM模型相关操作
def list_running_models(
self,
controller_address: str = None,
):
'''
获取Fastchat中正运行的模型列表
'''
data = {
"controller_address": controller_address,
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
response = self.post(
"/llm_model/list_running_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
从服务器上获取当前运行的LLM模型
local_first=True 优先返回运行中的本地模型否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回
返回类型为model_name, is_local_model
'''
def ret_sync():
running_models = self.list_running_models()
if not running_models:
return "", False
model = ""
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
if local_first and not is_local:
continue
else:
model = m
break
if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
async def ret_async():
running_models = await self.list_running_models()
if not running_models:
return "", False
model = ""
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
if local_first and not is_local:
continue
else:
model = m
break
if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
if self._use_async:
return ret_async()
else:
return ret_sync()
def list_config_models(
self,
types: List[str] = ["local", "online"],
) -> Dict[str, Dict]:
'''
获取服务器configs中配置的模型列表返回形式为{"type": {model_name: config}, ...}
'''
data = {
"types": types,
}
response = self.post(
"/llm_model/list_config_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def get_model_config(
self,
model_name: str = None,
) -> Dict:
'''
获取服务器上模型配置
'''
data = {
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/stop",
json=data,
)
return self._get_response_value(response, as_json=True)
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
'''
if not model_name or not new_model_name:
return {
"code": 500,
"msg": f"未指定模型名称"
}
def ret_sync():
running_models = self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "无需切换"
}
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models.get("local", {}):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/change",
json=data,
)
return self._get_response_value(response, as_json=True)
async def ret_async():
running_models = await self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "无需切换"
}
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = await self.list_config_models()
if new_model_name not in config_models.get("local", {}):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/change",
json=data,
)
return self._get_response_value(response, as_json=True)
if self._use_async:
return ret_async()
else:
return ret_sync()
def embed_texts(
self,
texts: List[str],