Compatibility with model_providers: initialize the platform configuration used by the WebUI and the API (#3625)

* provider_configuration init of MODEL_PLATFORMS

* Development manual

* Compatibility with model_providers: initialize the platform configuration in the WebUI and the API
glide-the 2024-04-03 12:16:38 +08:00 committed by GitHub
parent c0634828a4
commit b3dee0b1d1
17 changed files with 510 additions and 599 deletions

.gitignore
View File

@ -183,4 +183,5 @@ configs/*.py
/knowledge_base/samples/content/imi_temeplate.txt
/chatchat/configs/*.py
/chatchat/configs/*.yaml
chatchat/data
chatchat/data
/chatchat-server/chatchat/configs/model_providers.yaml

View File

@ -1,26 +0,0 @@
log_path: "logs"
log_level: "DEBUG"
api_server:
host: "127.0.0.1"
port: 8000
publish_server:
host: "127.0.0.1"
port: 8001
subscribe_server:
host: "127.0.0.1"
port: 8002
openai_plugins_folder:
- "openai_plugins"
openai_plugins_load_folder:
- "configs"
plugins:
- openai:
name: "openai"
- zhipuai:
name: "zhipuai"

View File

@ -1,6 +1,5 @@
import os
# Name of the LLM selected by default
DEFAULT_LLM_MODEL = "chatglm3-6b"
@ -31,7 +30,7 @@ SUPPORT_AGENT_MODELS = [
LLM_MODEL_CONFIG = {
# Intent recognition does not need to produce output; it is enough for the model backend to know it
"preprocess_model": {
DEFAULT_LLM_MODEL: {
"temperature": 0.05,
@ -57,7 +56,7 @@ LLM_MODEL_CONFIG = {
"prompt_name": "ChatGLM3",
"callbacks": True
},
},
},
"postprocess_model": {
DEFAULT_LLM_MODEL: {
"temperature": 0.01,
@ -76,47 +75,15 @@ LLM_MODEL_CONFIG = {
},
}
# You can start a model service via loom/xinference/oneapi/fastchat and then configure its URL and KEY here.
# model_providers can expose the interfaces of different platforms as an openai endpoint; once it is started, the corresponding platforms are added to the variable below automatically
# ### If you already have an address with openai-endpoint capability, you can configure it here directly
# - platform_name can be any value, as long as it is not duplicated
# - platform_type can be openai, xinference, oneapi or fastchat; features may later be differentiated by platform type
# - platform_type may later be used to differentiate features by platform type; keeping it identical to platform_name is fine
# - Fill the models deployed by each framework into the corresponding lists. Different frameworks can load models with the same name, and the project balances the load across them automatically.
MODEL_PLATFORMS = [
# {
# "platform_name": "openai-api",
# "platform_type": "openai",
# "api_base_url": "https://api.openai.com/v1",
# "api_key": "sk-",
# "api_proxy": "",
# "api_concurrencies": 5,
# "llm_models": [
# "gpt-3.5-turbo",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
{
"platform_name": "xinference",
"platform_type": "xinference",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
"api_concurrencies": 5,
# Note: this is the UID of the model deployed by xinference, not the model name
"llm_models": [
"chatglm3-6b",
],
"embed_models": [
"bge-large-zh-v1.5",
],
"image_models": [
"sd-turbo",
],
"multimodal_models": [
"qwen-vl",
],
},
# Create a global shared dictionary
MODEL_PLATFORMS = [
{
"platform_name": "oneapi",
@ -152,41 +119,13 @@ MODEL_PLATFORMS = [
"multimodal_models": [],
},
{
"platform_name": "ollama",
"platform_type": "ollama",
"api_base_url": "http://{host}:{port}/v1",
"api_key": "sk-",
"api_concurrencies": 5,
"llm_models": [
# Qwen API; for more models see https://ollama.com/library
"qwen:7b",
],
"embed_models": [
# ollama must be upgraded to 0.1.29 or later; the embedding service is broken in older versions
"nomic-embed-text"
],
"image_models": [],
"multimodal_models": [],
},
# {
# "platform_name": "loom",
# "platform_type": "loom",
# "api_base_url": "http://127.0.0.1:7860/v1",
# "api_key": "",
# "api_concurrencies": 5,
# "llm_models": [
# "chatglm3-6b",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
]
LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml")
MODEL_PROVIDERS_CFG_HOST = "127.0.0.1"
MODEL_PROVIDERS_CFG_PORT = 20000
# Tool configuration
TOOL_CONFIG = {
"search_local_knowledgebase": {

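For illustration only (not part of this diff): following the comment block above, a platform that already exposes an OpenAI-compatible endpoint could be registered by adding one more entry to MODEL_PLATFORMS. Every value below is a placeholder; only the field names follow the structure shown above.

# Hypothetical extra entry; the platform name, URL, key and model names are placeholders.
MODEL_PLATFORMS.append(
    {
        "platform_name": "my-openai-endpoint",  # any unique name
        "platform_type": "my-openai-endpoint",  # kept identical to platform_name
        "api_base_url": "http://127.0.0.1:8000/v1",
        "api_key": "EMPTY",
        "api_concurrencies": 5,
        "llm_models": ["chatglm3-6b"],
        "embed_models": ["bge-large-zh-v1.5"],
        "image_models": [],
        "multimodal_models": [],
    }
)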
View File

@ -0,0 +1,29 @@
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'

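For orientation (not part of this diff): the new file above is plain YAML, so it can be inspected before the model_providers service consumes it. The snippet below is a hedged sketch; it assumes PyYAML is installed and relies only on the structure shown above (a provider name at the top level, a model_credential list and an optional provider_credential block underneath).

import yaml  # PyYAML, assumed to be available

from chatchat.configs import MODEL_PROVIDERS_CFG_PATH_CONFIG

with open(MODEL_PROVIDERS_CFG_PATH_CONFIG, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Top-level keys are provider names (openai, xinference, ...);
# each provider lists its models under model_credential.
for provider_name, provider_cfg in cfg.items():
    models = [m["model"] for m in provider_cfg.get("model_credential", [])]
    print(provider_name, models)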
View File

@ -1,6 +0,0 @@
{
"openai_plugins": [
"imitater", "openai"
]
}

View File

@ -0,0 +1,109 @@
from typing import List, Dict
from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG
from model_providers import BootstrapWebBuilder
from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
import multiprocessing as mp
import asyncio
import logging
logger = logging.getLogger(__name__)
def init_server(model_platforms_shard: Dict,
started_event: mp.Event = None,
model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
provider_host: str = MODEL_PROVIDERS_CFG_HOST,
provider_port: int = MODEL_PROVIDERS_CFG_PORT,
log_path: str = "logs"
) -> None:
logging_conf = get_config_dict(
"DEBUG",
get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"),
122,
111,
)
try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=model_providers_cfg_path
)
.host(host=provider_host)
.port(port=provider_port)
.build()
)
boot.set_app_event(started_event=started_event)
provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
model_platforms_shard['provider_platforms'] = provider_platforms
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise
def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
provider_list: List[ProviderResponse] = ProvidersWrapper(
provider_manager=provider_manager).get_provider_list()
logger.info(f"Provider list: {provider_list}")
# Convert into MODEL_PLATFORMS entries
provider_platforms = []
for provider in provider_list:
provider_dict = {
"platform_name": provider.provider,
"platform_type": provider.provider,
"api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1",
"api_key": "EMPTY",
"api_concurrencies": 5
}
provider_dict["llm_models"] = []
provider_dict["embed_models"] = []
provider_dict["image_models"] = []
provider_dict["multimodal_models"] = []
supported_model_str_types = [model_type.to_origin_model_type() for model_type in
provider.supported_model_types]
for model_type in supported_model_str_types:
providers_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type)
cur_model_type: List[str] = []
# Query the models of the current provider
for provider_model in providers_model_type:
if provider_model.provider == provider.provider:
models = [model.model for model in provider_model.models]
cur_model_type.extend(models)
if cur_model_type:
if model_type == "text-generation":
provider_dict["llm_models"] = cur_model_type
elif model_type == "text-embedding":
provider_dict["embed_models"] = cur_model_type
elif model_type == "text2img":
provider_dict["image_models"] = cur_model_type
elif model_type == "multimodal":
provider_dict["multimodal_models"] = cur_model_type
else:
logger.warning(f"Unsupported model type: {model_type}")
provider_platforms.append(provider_dict)
logger.info(f"Provider platforms: {provider_platforms}")
return provider_platforms

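For orientation (not part of this diff): init_server is meant to run in its own process, and the Manager dict is how the computed provider_platforms travel back to the parent. A minimal driver, written as a sketch against the signature above, could look roughly like this:

import multiprocessing as mp

from chatchat.model_loaders.init_server import init_server

if __name__ == "__main__":
    manager = mp.Manager()
    model_platforms_shard = manager.dict()  # shared across processes
    started = manager.Event()

    p = mp.Process(
        target=init_server,
        kwargs=dict(model_platforms_shard=model_platforms_shard,
                    started_event=started),
        daemon=True,
    )
    p.start()
    started.wait()  # the event is set once the model_providers web app is up
    print(model_platforms_shard.get("provider_platforms"))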
View File

@ -27,9 +27,10 @@ from typing import (
import logging
from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT,
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE)
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE,
MODEL_PLATFORMS)
from chatchat.server.pydantic_v2 import BaseModel, Field
from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used?
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -47,17 +48,18 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
def get_config_platforms() -> Dict[str, Dict]:
import importlib
from chatchat.configs import model_config
importlib.reload(model_config)
# import importlib
# reloading is not supported
# from chatchat.configs import model_config
# importlib.reload(model_config)
return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS}
return {m["platform_name"]: m for m in MODEL_PLATFORMS}
def get_config_models(
model_name: str = None,
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
platform_name: str = None,
model_name: str = None,
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
platform_name: str = None,
) -> Dict[str, Dict]:
'''
Get the list of configured models; the return value is:
@ -71,12 +73,13 @@ def get_config_models(
"api_proxy": xx,
}}
'''
import importlib
from chatchat.configs import model_config
importlib.reload(model_config)
# import importlib
# reloading is not supported
# from chatchat.configs import model_config
# importlib.reload(model_config)
result = {}
for m in model_config.MODEL_PLATFORMS:
for m in MODEL_PLATFORMS:
if platform_name is not None and platform_name != m.get("platform_name"):
continue
if model_type is not None and f"{model_type}_models" not in m:
@ -124,24 +127,24 @@ def get_ChatOpenAI(
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
local_wrap: bool = False, # use local wrapped api
local_wrap: bool = False, # use local wrapped api
**kwargs: Any,
) -> ChatOpenAI:
model_info = get_model_info(model_name)
params = dict(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
)
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
openai_api_base=f"{api_address()}/v1",
openai_api_key="EMPTY",
)
else:
params.update(
@ -164,7 +167,7 @@ def get_OpenAI(
echo: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
local_wrap: bool = False, # use local wrapped api
local_wrap: bool = False, # use local wrapped api
**kwargs: Any,
) -> OpenAI:
# TODO: fetch model info from the API
@ -182,8 +185,8 @@ def get_OpenAI(
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
openai_api_base=f"{api_address()}/v1",
openai_api_key="EMPTY",
)
else:
params.update(
@ -199,20 +202,20 @@ def get_OpenAI(
def get_Embeddings(
embed_model: str = DEFAULT_EMBEDDING_MODEL,
local_wrap: bool = False, # use local wrapped api
embed_model: str = DEFAULT_EMBEDDING_MODEL,
local_wrap: bool = False, # use local wrapped api
) -> Embeddings:
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154
model_info = get_model_info(model_name=embed_model)
params = dict(model=embed_model)
try:
if local_wrap:
params.update(
openai_api_base = f"{api_address()}/v1",
openai_api_key = "EMPTY",
openai_api_base=f"{api_address()}/v1",
openai_api_key="EMPTY",
)
else:
params.update(
@ -223,7 +226,7 @@ def get_Embeddings(
if model_info.get("platform_type") == "openai":
return OpenAIEmbeddings(**params)
elif model_info.get("platform_type") == "ollama":
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''),
return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''),
model=embed_model,
)
else:
@ -233,9 +236,9 @@ def get_Embeddings(
def get_OpenAIClient(
platform_name: str=None,
model_name: str=None,
is_async: bool=True,
platform_name: str = None,
model_name: str = None,
is_async: bool = True,
) -> Union[openai.Client, openai.AsyncClient]:
'''
construct an openai Client for specified platform or model
@ -601,7 +604,7 @@ def run_in_process_pool(
tasks = []
max_workers = None
if sys.platform.startswith("win"):
max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows
max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows
with ProcessPoolExecutor(max_workers=max_workers) as pool:
for kwargs in params:
tasks.append(pool.submit(func, **kwargs))

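For orientation (not part of this diff): with MODEL_PLATFORMS now read at import time instead of being reloaded, the helpers can be called directly. A hedged usage sketch follows; the model and platform names are placeholders taken from the sample configuration earlier in this commit.

from chatchat.server.utils import get_config_platforms, get_config_models, get_ChatOpenAI

# all configured platforms, keyed by platform_name
platforms = get_config_platforms()

# embedding models exposed by one platform (names are placeholders)
embed_models = get_config_models(model_type="embed", platform_name="xinference")

# a ChatOpenAI instance pointed at the locally wrapped /v1 endpoint
llm = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.05, local_wrap=True)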
View File

@ -1,4 +1,5 @@
import asyncio
import multiprocessing
from contextlib import asynccontextmanager
import multiprocessing as mp
import os
@ -6,6 +7,7 @@ import subprocess
import sys
from multiprocessing import Process
from chatchat.model_loaders.init_server import init_server
# Set the maximum number of numexpr threads; defaults to the number of CPU cores
try:
@ -23,7 +25,7 @@ from chatchat.configs import (
DEFAULT_EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
API_SERVER,
WEBUI_SERVER,
WEBUI_SERVER, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT
)
from chatchat.server.utils import FastAPI
from chatchat.server.knowledge_base.migrate import create_tables
@ -38,15 +40,34 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None):
if started_event is not None:
started_event.set()
yield
app.router.lifespan_context = lifespan
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
def run_init_server(
model_platforms_shard: Dict,
started_event: mp.Event = None,
run_mode: str = None,
model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
provider_host: str = MODEL_PROVIDERS_CFG_HOST,
provider_port: int = MODEL_PROVIDERS_CFG_PORT):
init_server(model_platforms_shard=model_platforms_shard,
started_event=started_event,
model_providers_cfg_path=model_providers_cfg_path,
provider_host=provider_host,
provider_port=provider_port)
def run_api_server(model_platforms_shard: Dict,
started_event: mp.Event = None,
run_mode: str = None):
from chatchat.server.api_server.server_app import create_app
import uvicorn
from chatchat.server.utils import set_httpx_config
from chatchat.configs import MODEL_PLATFORMS
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
set_httpx_config()
app = create_app(run_mode=run_mode)
_set_app_event(app, started_event)
@ -56,48 +77,65 @@ def run_api_server(started_event: mp.Event = None, run_mode: str = None):
uvicorn.run(app, host=host, port=port)
def run_webui(started_event: mp.Event = None, run_mode: str = None):
def run_webui(model_platforms_shard: Dict,
started_event: mp.Event = None, run_mode: str = None):
import sys
from chatchat.server.utils import set_httpx_config
from chatchat.configs import MODEL_PLATFORMS
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}")
set_httpx_config()
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
# Check whether the system is Windows
if sys.platform == "win32":
st_exe = os.path.join(os.path.dirname(sys.executable), "Scripts", "streamlit")
else:
st_exe = os.path.join(os.path.dirname(sys.executable),"streamlit")
script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py')
cmd = [st_exe, "run", script_dir,
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
flag_options = {'server_address': host,
'server_port': port,
'theme_base': 'light',
'theme_primaryColor': '#165dff',
'theme_secondaryBackgroundColor': '#f5f5f5',
'theme_textColor': '#000000',
'global_disableWatchdogWarning': None,
'global_disableWidgetStateDuplicationWarning': None,
'global_showWarningOnDirectExecution': None,
'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None,
'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None,
'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None,
'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None,
'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None,
'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None,
'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None,
'runner_postScriptGC': None, 'runner_fastReruns': None,
'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None,
'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None,
'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None,
'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None,
'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None,
'server_enableWebsocketCompression': None, 'server_enableStaticServing': None,
'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None,
'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None,
'ui_hideSidebarNav': None, 'magic_displayRootDocString': None,
'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None,
'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None,
'theme_backgroundColor': None, 'theme_font': None}
args = []
if run_mode == "lite":
cmd += [
args += [
"--",
"lite",
]
p = subprocess.Popen(cmd)
try:
# for streamlit >= 1.12.1
from streamlit.web import bootstrap
except ImportError:
from streamlit import bootstrap
bootstrap.run(script_dir, False, args, flag_options)
started_event.set()
p.wait()
def run_loom(started_event: mp.Event = None):
from chatchat.configs import LOOM_CONFIG
cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local",
"-f", LOOM_CONFIG
]
p = subprocess.Popen(cmd)
started_event.set()
p.wait()
def parse_args() -> argparse.ArgumentParser:
@ -106,13 +144,13 @@ def parse_args() -> argparse.ArgumentParser:
"-a",
"--all-webui",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py",
help="run model_providers servers, run api.py and webui.py",
dest="all_webui",
)
parser.add_argument(
"--all-api",
action="store_true",
help="run fastchat's controller/openai_api/model_worker servers, run api.py",
help="run model_providers servers, run api.py",
dest="all_api",
)
@ -156,11 +194,18 @@ def dump_server_info(after_start=False, args=None):
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前Embbedings模型 {DEFAULT_EMBEDDING_MODEL}")
print(f"默认选用的 Embedding 名称 {DEFAULT_EMBEDDING_MODEL}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.api:
print(
f" Chatchat Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n"
f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n"
f" provider_port:{MODEL_PROVIDERS_CFG_PORT}\n")
print(f" Chatchat Api Server: {api_address()}")
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
@ -193,21 +238,16 @@ async def start_main_server():
args, parser = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = False
if args.lite:
args.model_worker = False
run_mode = "lite"
dump_server_info(args=args)
@ -216,25 +256,29 @@ async def start_main_server():
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {"online_api": {}, "model_worker": {}}
processes = {}
def process_count():
return len(processes)
loom_started = manager.Event()
# process = Process(
# target=run_loom,
# name=f"run_loom Server",
# kwargs=dict(started_event=loom_started),
# daemon=True,
# )
# processes["run_loom"] = process
# Define the global configuration variable; use a Manager to create the shared dict
model_platforms_shard = manager.dict()
model_providers_started = manager.Event()
if args.api:
process = Process(
target=run_init_server,
name=f"Model providers Server",
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started,
run_mode=run_mode),
daemon=True,
)
processes["model_providers"] = process
api_started = manager.Event()
if args.api:
process = Process(
target=run_api_server,
name=f"API Server",
kwargs=dict(started_event=api_started, run_mode=run_mode),
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
daemon=True,
)
processes["api"] = process
@ -244,7 +288,7 @@ async def start_main_server():
process = Process(
target=run_webui,
name=f"WEBUI Server",
kwargs=dict(started_event=webui_started, run_mode=run_mode),
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode),
daemon=True,
)
processes["webui"] = process
@ -254,10 +298,10 @@ async def start_main_server():
else:
try:
# Make sure the tasks can exit cleanly after receiving SIGINT
if p := processes.get("run_loom"):
if p := processes.get("model_providers"):
p.start()
p.name = f"{p.name} ({p.pid})"
loom_started.wait() # wait for Loom to finish starting
model_providers_started.wait() # wait for model_providers to finish starting
if p := processes.get("api"):
p.start()
@ -295,6 +339,8 @@ async def start_main_server():
if __name__ == "__main__":
# add this line
multiprocessing.freeze_support()
create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()

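For orientation (not part of this diff): run_init_server, run_api_server and run_webui each execute in a separate process, so a module-level list changed in one of them would not be visible to the others; the Manager dict is the shared channel, and each consumer extends its own MODEL_PLATFORMS from it. A toy illustration of that pattern, unrelated to the project code:

import multiprocessing as mp


def producer(shared):
    shared["provider_platforms"] = [{"platform_name": "demo"}]


def consumer(shared):
    # mirrors how run_api_server/run_webui extend MODEL_PLATFORMS
    local_platforms = []
    local_platforms.extend(shared["provider_platforms"])
    print(local_platforms)


if __name__ == "__main__":
    manager = mp.Manager()
    shared = manager.dict()
    p1 = mp.Process(target=producer, args=(shared,))
    p1.start()
    p1.join()
    p2 = mp.Process(target=consumer, args=(shared,))
    p2.start()
    p2.join()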
View File

@ -1,184 +0,0 @@
from typing import Tuple, Any
import streamlit as st
from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient
import logging
logger = logging.getLogger(__name__)
client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False)
def update_store():
logger.info("update_status")
st.session_state.status = client.status()
logger.info("update_list_plugins")
list_plugins = client.list_plugins()
st.session_state.run_plugins_list = list_plugins.get("plugins_list", [])
logger.info("update_launch_subscribe_info")
launch_subscribe_info = {}
for plugin_name in st.session_state.run_plugins_list:
launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name)
st.session_state.launch_subscribe_info = launch_subscribe_info
logger.info("update_list_running_models")
list_running_models = {}
for plugin_name in st.session_state.run_plugins_list:
list_running_models[plugin_name] = client.list_running_models(plugin_name)
st.session_state.list_running_models = list_running_models
logger.info("update_model_config")
model_config = {}
for plugin_name in st.session_state.run_plugins_list:
model_config[plugin_name] = client.list_llm_models(plugin_name)
st.session_state.model_config = model_config
def start_plugin():
import time
start_plugins_name = st.session_state.plugins_name
if start_plugins_name in st.session_state.run_plugins_list:
st.toast(start_plugins_name + " has already been counted.")
time.sleep(.5)
else:
st.toast("start_plugin " + start_plugins_name + ",starting.")
result = client.launch_subscribe(start_plugins_name)
st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", ""))
time.sleep(3)
result1 = client.launch_subscribe_start(start_plugins_name)
st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", ""))
time.sleep(2)
update_store()
def start_worker():
select_plugins_name = st.session_state.plugins_name
select_worker_id = st.session_state.worker_id
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
already_counted = False
for model in start_model_list:
if model['worker_id'] == select_worker_id:
already_counted = True
break
if already_counted:
st.toast(
"select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.")
import time
time.sleep(.5)
else:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.")
result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id)
st.toast("start worker_id " + select_worker_id + " ." + result.get("detail", ""))
import time
time.sleep(.5)
update_store()
def stop_worker():
select_plugins_name = st.session_state.plugins_name
select_worker_id = st.session_state.worker_id
start_model_list = st.session_state.list_running_models.get(select_plugins_name, [])
already_counted = False
for model in start_model_list:
if model['worker_id'] == select_worker_id:
already_counted = True
break
if not already_counted:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already")
import time
time.sleep(.5)
else:
st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.")
result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id)
st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", ""))
import time
time.sleep(.5)
update_store()
def build_providers_model_plugins_name():
import streamlit_antd_components as sac
if "run_plugins_list" not in st.session_state:
return []
# Build the sac.menu items grouped by model
menu_items = []
for key, value in st.session_state.list_running_models.items():
menu_item_children = []
for model in value:
if "model" in model["providers"]:
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
return menu_items
def build_providers_embedding_plugins_name():
import streamlit_antd_components as sac
if "run_plugins_list" not in st.session_state:
return []
# Build the sac.menu items grouped by model
menu_items = []
for key, value in st.session_state.list_running_models.items():
menu_item_children = []
for model in value:
if "embedding" in model["providers"]:
menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"]))
menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children))
return menu_items
def find_menu_items_by_index(menu_items, key):
for menu_item in menu_items:
if menu_item.get('children') is not None:
for child in menu_item.get('children'):
if child.get('key') == key:
return menu_item, child
return None, None
def set_llm_select(plugins_info, llm_model_worker):
st.session_state["select_plugins_info"] = plugins_info
st.session_state["select_model_worker"] = llm_model_worker
def get_select_model_endpoint() -> Tuple[str, str]:
plugins_info = st.session_state["select_plugins_info"]
llm_model_worker = st.session_state["select_model_worker"]
if plugins_info is None or llm_model_worker is None:
raise ValueError("plugins_info or llm_model_worker is None")
plugins_name = st.session_state["select_plugins_info"]['label']
select_model_name = st.session_state["select_model_worker"]['label']
adapter_description = st.session_state.launch_subscribe_info[plugins_name]
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
return endpoint_host, select_model_name
def set_embed_select(plugins_info, embed_model_worker):
st.session_state["select_embed_plugins_info"] = plugins_info
st.session_state["select_embed_model_worker"] = embed_model_worker
def get_select_embed_endpoint() -> Tuple[str, str]:
select_embed_plugins_info = st.session_state["select_embed_plugins_info"]
select_embed_model_worker = st.session_state["select_embed_model_worker"]
if select_embed_plugins_info is None or select_embed_model_worker is None:
raise ValueError("select_embed_plugins_info or select_embed_model_worker is None")
embed_plugins_name = st.session_state["select_embed_plugins_info"]['label']
select_embed_model_name = st.session_state["select_embed_model_worker"]['label']
endpoint_host = None
if embed_plugins_name in st.session_state.launch_subscribe_info:
adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name]
endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "")
return endpoint_host, select_embed_model_name

View File

@ -1 +0,0 @@
from .base import openai_plugins_page

View File

@ -1,67 +0,0 @@
import streamlit as st
from loom_openai_plugins_frontend import loom_openai_plugins_frontend
from chatchat.webui_pages.utils import ApiRequest
from chatchat.webui_pages.loom_view_client import (
update_store,
start_plugin,
start_worker,
stop_worker,
)
def openai_plugins_page(api: ApiRequest, is_lite: bool = None):
with (st.container()):
if "worker_id" not in st.session_state:
st.session_state.worker_id = ''
if "plugins_name" not in st.session_state and "status" in st.session_state:
for k, v in st.session_state.status.get("status", {}).get("subscribe_status", []).items():
st.session_state.plugins_name = v.get("plugins_names", [])[0]
break
col1, col2 = st.columns([0.8, 0.2])
with col1:
event = loom_openai_plugins_frontend(plugins_status=st.session_state.status,
run_list_plugins=st.session_state.run_plugins_list,
launch_subscribe_info=st.session_state.launch_subscribe_info,
list_running_models=st.session_state.list_running_models,
model_config=st.session_state.model_config)
with col2:
st.write("操作")
if not st.session_state.run_plugins_list:
button_type_disabled = False
button_start_text = '启动'
else:
button_type_disabled = True
button_start_text = '已启动'
if event:
event_type = event.get("event")
if event_type == "BottomNavigationAction":
st.session_state.plugins_name = event.get("data")
st.session_state.worker_id = ''
# If it is not in run_plugins_list, enable the start button
if st.session_state.plugins_name not in st.session_state.run_plugins_list \
or st.session_state.run_plugins_list:
button_type_disabled = False
button_start_text = '启动'
else:
button_type_disabled = True
button_start_text = '已启动'
if event_type == "CardCoverComponent":
st.session_state.worker_id = event.get("data")
st.button(button_start_text, disabled=button_type_disabled, key="start",
on_click=start_plugin)
if st.session_state.worker_id != '':
st.button("启动" + st.session_state.worker_id, key="start_worker",
on_click=start_worker)
st.button("停止" + st.session_state.worker_id, key="stop_worker",
on_click=stop_worker)

View File

@ -216,7 +216,7 @@ build-backend = "poetry.core.masonry.api"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

View File

@ -0,0 +1,6 @@
from chatchat.model_loaders.init_server import init_server
def test_init_server():
init_server(model_platforms_shard={})

View File

@ -0,0 +1,47 @@
import argparse
import asyncio
import logging
from model_providers import BootstrapWebBuilder
from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms
logger = logging.getLogger(__name__)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-providers",
type=str,
default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
help="run model_providers servers",
dest="model_providers",
)
args = parser.parse_args()
try:
logging_conf = get_config_dict(
"DEBUG",
get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"),
122,
111,
)
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=args.model_providers
)
.host(host="127.0.0.1")
.port(port=20000)
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise

View File

@ -45,6 +45,7 @@ from model_providers.core.bootstrap.openai_protocol import (
Role,
UsageInfo,
)
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
@ -111,7 +112,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
@ -170,7 +171,7 @@ def _create_template_from_message_type(
def _convert_to_message(
message: MessageLikeRepresentation,
message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
"""Instantiate a message from a variety of message formats.
@ -212,7 +213,7 @@ def _convert_to_message(
async def _stream_openai_chat_completion(
response: Generator,
response: Generator,
) -> AsyncGenerator[str, None]:
request_id, model = None, None
for chunk in response:
@ -362,11 +363,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
started_event.set()
async def workspaces_model_providers(self, request: Request):
provider_list = self.get_provider_list(model_type=request.get("model_type"))
provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
model_type=request.get("model_type"))
return ProviderListResponse(data=provider_list)
async def workspaces_model_types(self, model_type: str, request: Request):
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
models_by_model_type = ProvidersWrapper(
provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
return ProviderModelTypeResponse(data=models_by_model_type)
async def list_models(self, provider: str, request: Request):
@ -399,7 +403,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return ModelList(data=models_list)
async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
@ -409,7 +413,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return EmbeddingsResponse(**dictify(response))
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@ -469,9 +473,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:

View File

@ -4,22 +4,11 @@ from typing import List, Optional
from fastapi import Request
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
class Bootstrap:
@ -43,137 +32,6 @@ class Bootstrap:
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# Merge the keys of the two dicts
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# Merge the keys of the two dicts
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
@classmethod
@abstractmethod
def from_config(cls, cfg=None):

View File

@ -0,0 +1,153 @@
from typing import Optional, List
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
class ProvidersWrapper:
def __init__(self, provider_manager: ProviderManager):
self.provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# Merge the keys of the two dicts
provider = set(
self.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# Merge the keys of the two dicts
provider = set(
self.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
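
For orientation (not part of this diff): ProvidersWrapper now holds the listing logic that previously lived on Bootstrap, so it can be used standalone wherever a ProviderManager is available, as init_server does above. A hedged sketch; the "text-generation" model type string is the one used in init_server.

from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper


def list_text_generation_models(provider_manager):
    # provider_manager: a model_providers ProviderManager instance
    wrapper = ProvidersWrapper(provider_manager=provider_manager)
    for provider in wrapper.get_provider_list():
        print(provider.provider, provider.supported_model_types)
    return wrapper.get_models_by_model_type(model_type="text-generation")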