diff --git a/.gitignore b/.gitignore index 2d3f1173..b54db8e1 100644 --- a/.gitignore +++ b/.gitignore @@ -183,4 +183,5 @@ configs/*.py /knowledge_base/samples/content/imi_temeplate.txt /chatchat/configs/*.py /chatchat/configs/*.yaml -chatchat/data \ No newline at end of file +chatchat/data +/chatchat-server/chatchat/configs/model_providers.yaml diff --git a/chatchat-server/chatchat/configs/loom.yaml.example b/chatchat-server/chatchat/configs/loom.yaml.example deleted file mode 100644 index b3d9d1cb..00000000 --- a/chatchat-server/chatchat/configs/loom.yaml.example +++ /dev/null @@ -1,26 +0,0 @@ -log_path: "logs" -log_level: "DEBUG" - -api_server: - host: "127.0.0.1" - port: 8000 - -publish_server: - host: "127.0.0.1" - port: 8001 - -subscribe_server: - host: "127.0.0.1" - port: 8002 - -openai_plugins_folder: - - "openai_plugins" -openai_plugins_load_folder: - - "configs" - - -plugins: - - openai: - name: "openai" - - zhipuai: - name: "zhipuai" diff --git a/chatchat-server/chatchat/configs/model_config.py.example b/chatchat-server/chatchat/configs/model_config.py.example index a38f1419..0ed1950b 100644 --- a/chatchat-server/chatchat/configs/model_config.py.example +++ b/chatchat-server/chatchat/configs/model_config.py.example @@ -1,6 +1,5 @@ import os - # 默认选用的 LLM 名称 DEFAULT_LLM_MODEL = "chatglm3-6b" @@ -31,7 +30,7 @@ SUPPORT_AGENT_MODELS = [ LLM_MODEL_CONFIG = { -# 意图识别不需要输出,模型后台知道就行 + # 意图识别不需要输出,模型后台知道就行 "preprocess_model": { DEFAULT_LLM_MODEL: { "temperature": 0.05, @@ -57,7 +56,7 @@ LLM_MODEL_CONFIG = { "prompt_name": "ChatGLM3", "callbacks": True }, - }, + }, "postprocess_model": { DEFAULT_LLM_MODEL: { "temperature": 0.01, @@ -76,47 +75,15 @@ LLM_MODEL_CONFIG = { }, } -# 可以通过 loom/xinference/oneapi/fastchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。 +# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台 +# ### 如果您已经有了一个openai endpoint的能力的地址,可以在这里直接配置 # - platform_name 可以任意填写,不要重复即可 -# - platform_type 可选:openai, xinference, oneapi, fastchat。以后可能根据平台类型做一些功能区分 +# - platform_type 以后可能根据平台类型做一些功能区分,与platform_name一致即可 # - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。 -MODEL_PLATFORMS = [ - # { - # "platform_name": "openai-api", - # "platform_type": "openai", - # "api_base_url": "https://api.openai.com/v1", - # "api_key": "sk-", - # "api_proxy": "", - # "api_concurrencies": 5, - # "llm_models": [ - # "gpt-3.5-turbo", - # ], - # "embed_models": [], - # "image_models": [], - # "multimodal_models": [], - # }, - { - "platform_name": "xinference", - "platform_type": "xinference", - "api_base_url": "http://127.0.0.1:9997/v1", - "api_key": "EMPTY", - "api_concurrencies": 5, - # 注意:这里填写的是 xinference 部署的模型 UID,而非模型名称 - "llm_models": [ - "chatglm3-6b", - ], - "embed_models": [ - "bge-large-zh-v1.5", - ], - "image_models": [ - "sd-turbo", - ], - "multimodal_models": [ - "qwen-vl", - ], - }, +# 创建一个全局的共享字典 +MODEL_PLATFORMS = [ { "platform_name": "oneapi", @@ -152,41 +119,13 @@ MODEL_PLATFORMS = [ "multimodal_models": [], }, - { - "platform_name": "ollama", - "platform_type": "ollama", - "api_base_url": "http://{host}:{port}/v1", - "api_key": "sk-", - "api_concurrencies": 5, - "llm_models": [ - # Qwen API,其它更多模型请参考https://ollama.com/library - "qwen:7b", - ], - "embed_models": [ - # 必须升级ollama到0.1.29以上,低版本向量服务有问题 - "nomic-embed-text" - ], - "image_models": [], - "multimodal_models": [], - }, - # { - # "platform_name": "loom", - # "platform_type": "loom", - # "api_base_url": "http://127.0.0.1:7860/v1", - # "api_key": "", - # "api_concurrencies": 5, - # "llm_models": [ - # "chatglm3-6b", - # ], - # 
"embed_models": [], - # "image_models": [], - # "multimodal_models": [], - # }, ] -LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml") +MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml") +MODEL_PROVIDERS_CFG_HOST = "127.0.0.1" +MODEL_PROVIDERS_CFG_PORT = 20000 # 工具配置项 TOOL_CONFIG = { "search_local_knowledgebase": { diff --git a/chatchat-server/chatchat/configs/model_providers.yaml.example b/chatchat-server/chatchat/configs/model_providers.yaml.example new file mode 100644 index 00000000..d88736b3 --- /dev/null +++ b/chatchat-server/chatchat/configs/model_providers.yaml.example @@ -0,0 +1,29 @@ +openai: + model_credential: + - model: 'gpt-3.5-turbo' + model_type: 'llm' + model_credentials: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + - model: 'gpt-4' + model_type: 'llm' + model_credentials: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + + provider_credential: + openai_api_key: 'sk-' + openai_organization: '' + openai_api_base: '' + +xinference: + model_credential: + - model: 'chatglm3-6b' + model_type: 'llm' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 'chatglm3-6b' + + diff --git a/chatchat-server/chatchat/configs/openai-plugins-list.json b/chatchat-server/chatchat/configs/openai-plugins-list.json deleted file mode 100644 index 88f2cb63..00000000 --- a/chatchat-server/chatchat/configs/openai-plugins-list.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "openai_plugins": [ - "imitater", "openai" - ] - -} diff --git a/chatchat-server/chatchat/model_loaders/init_server.py b/chatchat-server/chatchat/model_loaders/init_server.py new file mode 100644 index 00000000..ef909af6 --- /dev/null +++ b/chatchat-server/chatchat/model_loaders/init_server.py @@ -0,0 +1,109 @@ +from typing import List, Dict +from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG +from model_providers import BootstrapWebBuilder +from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper +from model_providers.core.provider_manager import ProviderManager +from model_providers.core.utils.utils import ( + get_config_dict, + get_log_file, + get_timestamp_ms, +) +import multiprocessing as mp +import asyncio +import logging + +logger = logging.getLogger(__name__) + + +def init_server(model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG, + provider_host: str = MODEL_PROVIDERS_CFG_HOST, + provider_port: int = MODEL_PROVIDERS_CFG_PORT, + log_path: str = "logs" + ) -> None: + logging_conf = get_config_dict( + "DEBUG", + get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"), + 122, + 111, + ) + + try: + boot = ( + BootstrapWebBuilder() + .model_providers_cfg_path( + model_providers_cfg_path=model_providers_cfg_path + ) + .host(host=provider_host) + .port(port=provider_port) + .build() + ) + boot.set_app_event(started_event=started_event) + + provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager) + model_platforms_shard['provider_platforms'] = provider_platforms + + boot.serve(logging_conf=logging_conf) + + async def pool_join_thread(): + await boot.join() + + asyncio.run(pool_join_thread()) + except SystemExit: + logger.info("SystemExit raised, exiting") + raise + + +def 
init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: + provider_list: List[ProviderResponse] = ProvidersWrapper( + provider_manager=provider_manager).get_provider_list() + logger.info(f"Provider list: {provider_list}") + # 转换MODEL_PLATFORMS + provider_platforms = [] + for provider in provider_list: + provider_dict = { + "platform_name": provider.provider, + "platform_type": provider.provider, + "api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1", + "api_key": "EMPTY", + "api_concurrencies": 5 + } + + provider_dict["llm_models"] = [] + provider_dict["embed_models"] = [] + provider_dict["image_models"] = [] + provider_dict["multimodal_models"] = [] + supported_model_str_types = [model_type.to_origin_model_type() for model_type in + provider.supported_model_types] + + for model_type in supported_model_str_types: + + providers_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type) + cur_model_type: List[str] = [] + # 查询当前provider的模型 + for provider_model in providers_model_type: + if provider_model.provider == provider.provider: + models = [model.model for model in provider_model.models] + cur_model_type.extend(models) + + if cur_model_type: + if model_type == "text-generation": + provider_dict["llm_models"] = cur_model_type + elif model_type == "text-embedding": + provider_dict["embed_models"] = cur_model_type + elif model_type == "text2img": + provider_dict["image_models"] = cur_model_type + elif model_type == "multimodal": + provider_dict["multimodal_models"] = cur_model_type + else: + logger.warning(f"Unsupported model type: {model_type}") + + provider_platforms.append(provider_dict) + + logger.info(f"Provider platforms: {provider_platforms}") + + return provider_platforms \ No newline at end of file diff --git a/chatchat-server/chatchat/server/utils.py b/chatchat-server/chatchat/server/utils.py index 9db43145..06a73f05 100644 --- a/chatchat-server/chatchat/server/utils.py +++ b/chatchat-server/chatchat/server/utils.py @@ -27,9 +27,10 @@ from typing import ( import logging from chatchat.configs import (logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, - DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE) + DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, TEMPERATURE, + MODEL_PLATFORMS) from chatchat.server.pydantic_v2 import BaseModel, Field -from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used? +from chatchat.server.minx_chat_openai import MinxChatOpenAI # TODO: still used? 
async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -47,17 +48,18 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): def get_config_platforms() -> Dict[str, Dict]: - import importlib - from chatchat.configs import model_config - importlib.reload(model_config) + # import importlib + # 不能支持重载 + # from chatchat.configs import model_config + # importlib.reload(model_config) - return {m["platform_name"]: m for m in model_config.MODEL_PLATFORMS} + return {m["platform_name"]: m for m in MODEL_PLATFORMS} def get_config_models( - model_name: str = None, - model_type: Literal["llm", "embed", "image", "multimodal"] = None, - platform_name: str = None, + model_name: str = None, + model_type: Literal["llm", "embed", "image", "multimodal"] = None, + platform_name: str = None, ) -> Dict[str, Dict]: ''' 获取配置的模型列表,返回值为: @@ -71,12 +73,13 @@ def get_config_models( "api_proxy": xx, }} ''' - import importlib - from chatchat.configs import model_config - importlib.reload(model_config) + # import importlib + # 不能支持重载 + # from chatchat.configs import model_config + # importlib.reload(model_config) result = {} - for m in model_config.MODEL_PLATFORMS: + for m in MODEL_PLATFORMS: if platform_name is not None and platform_name != m.get("platform_name"): continue if model_type is not None and f"{model_type}_models" not in m: @@ -124,24 +127,24 @@ def get_ChatOpenAI( streaming: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = False, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> ChatOpenAI: model_info = get_model_info(model_name) params = dict( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - **kwargs + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + **kwargs ) try: if local_wrap: params.update( - openai_api_base = f"{api_address()}/v1", - openai_api_key = "EMPTY", + openai_api_base=f"{api_address()}/v1", + openai_api_key="EMPTY", ) else: params.update( @@ -164,7 +167,7 @@ def get_OpenAI( echo: bool = True, callbacks: List[Callable] = [], verbose: bool = True, - local_wrap: bool = False, # use local wrapped api + local_wrap: bool = False, # use local wrapped api **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 @@ -182,8 +185,8 @@ def get_OpenAI( try: if local_wrap: params.update( - openai_api_base = f"{api_address()}/v1", - openai_api_key = "EMPTY", + openai_api_base=f"{api_address()}/v1", + openai_api_key="EMPTY", ) else: params.update( @@ -199,20 +202,20 @@ def get_OpenAI( def get_Embeddings( - embed_model: str = DEFAULT_EMBEDDING_MODEL, - local_wrap: bool = False, # use local wrapped api + embed_model: str = DEFAULT_EMBEDDING_MODEL, + local_wrap: bool = False, # use local wrapped api ) -> Embeddings: from langchain_community.embeddings.openai import OpenAIEmbeddings from langchain_community.embeddings import OllamaEmbeddings - from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154 + from chatchat.server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154 model_info = get_model_info(model_name=embed_model) params = dict(model=embed_model) try: if local_wrap: params.update( - openai_api_base = f"{api_address()}/v1", - openai_api_key = "EMPTY", + openai_api_base=f"{api_address()}/v1", + openai_api_key="EMPTY", ) else: params.update( @@ -223,7 +226,7 @@ def get_Embeddings( 
if model_info.get("platform_type") == "openai": return OpenAIEmbeddings(**params) elif model_info.get("platform_type") == "ollama": - return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1',''), + return OllamaEmbeddings(base_url=model_info.get("api_base_url").replace('/v1', ''), model=embed_model, ) else: @@ -233,9 +236,9 @@ def get_Embeddings( def get_OpenAIClient( - platform_name: str=None, - model_name: str=None, - is_async: bool=True, + platform_name: str = None, + model_name: str = None, + is_async: bool = True, ) -> Union[openai.Client, openai.AsyncClient]: ''' construct an openai Client for specified platform or model @@ -601,7 +604,7 @@ def run_in_process_pool( tasks = [] max_workers = None if sys.platform.startswith("win"): - max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows + max_workers = min(mp.cpu_count(), 60) # max_workers should not exceed 60 on windows with ProcessPoolExecutor(max_workers=max_workers) as pool: for kwargs in params: tasks.append(pool.submit(func, **kwargs)) diff --git a/chatchat-server/chatchat/startup.py b/chatchat-server/chatchat/startup.py index caea2df3..b6249261 100644 --- a/chatchat-server/chatchat/startup.py +++ b/chatchat-server/chatchat/startup.py @@ -1,4 +1,5 @@ import asyncio +import multiprocessing from contextlib import asynccontextmanager import multiprocessing as mp import os @@ -6,6 +7,7 @@ import subprocess import sys from multiprocessing import Process +from chatchat.model_loaders.init_server import init_server # 设置numexpr最大线程数,默认为CPU核心数 try: @@ -23,7 +25,7 @@ from chatchat.configs import ( DEFAULT_EMBEDDING_MODEL, TEXT_SPLITTER_NAME, API_SERVER, - WEBUI_SERVER, + WEBUI_SERVER, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT ) from chatchat.server.utils import FastAPI from chatchat.server.knowledge_base.migrate import create_tables @@ -38,15 +40,34 @@ def _set_app_event(app: FastAPI, started_event: mp.Event = None): if started_event is not None: started_event.set() yield + app.router.lifespan_context = lifespan -def run_api_server(started_event: mp.Event = None, run_mode: str = None): +def run_init_server( + model_platforms_shard: Dict, + started_event: mp.Event = None, + run_mode: str = None, + model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG, + provider_host: str = MODEL_PROVIDERS_CFG_HOST, + provider_port: int = MODEL_PROVIDERS_CFG_PORT): + init_server(model_platforms_shard=model_platforms_shard, + started_event=started_event, + model_providers_cfg_path=model_providers_cfg_path, + provider_host=provider_host, + provider_port=provider_port) + + +def run_api_server(model_platforms_shard: Dict, + started_event: mp.Event = None, + run_mode: str = None): from chatchat.server.api_server.server_app import create_app import uvicorn from chatchat.server.utils import set_httpx_config + from chatchat.configs import MODEL_PLATFORMS + MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") set_httpx_config() - app = create_app(run_mode=run_mode) _set_app_event(app, started_event) @@ -56,48 +77,65 @@ def run_api_server(started_event: mp.Event = None, run_mode: str = None): uvicorn.run(app, host=host, port=port) -def run_webui(started_event: mp.Event = None, run_mode: str = None): +def run_webui(model_platforms_shard: Dict, + started_event: mp.Event = None, run_mode: str = None): import sys from chatchat.server.utils import set_httpx_config - + from chatchat.configs import 
MODEL_PLATFORMS + MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}") set_httpx_config() host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] - # 判断系统是否为Windows - if sys.platform == "win32": - st_exe = os.path.join(os.path.dirname(sys.executable), "Scripts", "streamlit") - else: - st_exe = os.path.join(os.path.dirname(sys.executable),"streamlit") + script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py') - cmd = [st_exe, "run", script_dir, - "--server.address", host, - "--server.port", str(port), - "--theme.base", "light", - "--theme.primaryColor", "#165dff", - "--theme.secondaryBackgroundColor", "#f5f5f5", - "--theme.textColor", "#000000", - ] + + flag_options = {'server_address': host, + 'server_port': port, + 'theme_base': 'light', + 'theme_primaryColor': '#165dff', + 'theme_secondaryBackgroundColor': '#f5f5f5', + 'theme_textColor': '#000000', + 'global_disableWatchdogWarning': None, + 'global_disableWidgetStateDuplicationWarning': None, + 'global_showWarningOnDirectExecution': None, + 'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None, + 'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None, + 'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None, + 'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None, + 'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None, + 'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None, + 'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None, + 'runner_postScriptGC': None, 'runner_fastReruns': None, + 'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None, + 'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None, + 'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None, + 'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None, + 'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None, + 'server_enableWebsocketCompression': None, 'server_enableStaticServing': None, + 'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None, + 'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None, + 'ui_hideSidebarNav': None, 'magic_displayRootDocString': None, + 'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None, + 'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None, + 'theme_backgroundColor': None, 'theme_font': None} + + args = [] if run_mode == "lite": - cmd += [ + args += [ "--", "lite", ] - p = subprocess.Popen(cmd) + + try: + # for streamlit >= 1.12.1 + from streamlit.web import bootstrap + except ImportError: + from streamlit import bootstrap + + bootstrap.run(script_dir, False, args, flag_options) started_event.set() - p.wait() - - -def run_loom(started_event: mp.Event = None): - from chatchat.configs import LOOM_CONFIG - - cmd = ["python", "-m", "loom_core.openai_plugins.deploy.local", - "-f", LOOM_CONFIG - ] - - p = subprocess.Popen(cmd) - started_event.set() - p.wait() def parse_args() -> argparse.ArgumentParser: @@ -106,13 +144,13 @@ def parse_args() -> argparse.ArgumentParser: "-a", "--all-webui", action="store_true", - help="run fastchat's 
controller/openai_api/model_worker servers, run api.py and webui.py", + help="run model_providers servers, run api.py and webui.py", dest="all_webui", ) parser.add_argument( "--all-api", action="store_true", - help="run fastchat's controller/openai_api/model_worker servers, run api.py", + help="run model_providers servers, run api.py", dest="all_api", ) @@ -156,11 +194,18 @@ def dump_server_info(after_start=False, args=None): print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}") - print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}") + print(f"默认选用的 Embedding 名称: {DEFAULT_EMBEDDING_MODEL}") if after_start: print("\n") print(f"服务端运行信息:") + if args.api: + print( + f" Chatchat Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n" + f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n" + f" provider_port:{MODEL_PROVIDERS_CFG_PORT}\n") + + print(f" Chatchat Api Server: {api_address()}") if args.webui: print(f" Chatchat WEBUI Server: {webui_address()}") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) @@ -193,21 +238,16 @@ async def start_main_server(): args, parser = parse_args() if args.all_webui: - args.openai_api = True - args.model_worker = True args.api = True args.api_worker = True args.webui = True elif args.all_api: - args.openai_api = True - args.model_worker = True args.api = True args.api_worker = True args.webui = False if args.lite: - args.model_worker = False run_mode = "lite" dump_server_info(args=args) @@ -216,25 +256,29 @@ logger.info(f"正在启动服务:") logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") - processes = {"online_api": {}, "model_worker": {}} + processes = {} def process_count(): return len(processes) - loom_started = manager.Event() - # process = Process( - # target=run_loom, - # name=f"run_loom Server", - # kwargs=dict(started_event=loom_started), - # daemon=True, - # ) - # processes["run_loom"] = process + # 定义全局配置变量,使用 Manager 创建共享字典 + model_platforms_shard = manager.dict() + model_providers_started = manager.Event() + if args.api: + process = Process( + target=run_init_server, + name=f"Model providers Server", + kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started, + run_mode=run_mode), + daemon=True, + ) + processes["model_providers"] = process api_started = manager.Event() if args.api: process = Process( target=run_api_server, name=f"API Server", - kwargs=dict(started_event=api_started, run_mode=run_mode), + kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode), daemon=True, ) processes["api"] = process @@ -244,7 +288,7 @@ process = Process( target=run_webui, name=f"WEBUI Server", - kwargs=dict(started_event=webui_started, run_mode=run_mode), + kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode), daemon=True, ) processes["webui"] = process @@ -254,10 +298,10 @@ else: try: # 保证任务收到SIGINT后,能够正常退出 - if p := processes.get("run_loom"): + if p := processes.get("model_providers"): p.start() p.name = f"{p.name} ({p.pid})" - loom_started.wait() # 等待Loom启动完成 + model_providers_started.wait() # 等待model_providers启动完成 if p := processes.get("api"): p.start() @@ -295,6 +339,8 @@ if __name__ == "__main__": + # Windows 打包(frozen)环境下需要调用 freeze_support() + multiprocessing.freeze_support() create_tables() if sys.version_info < (3, 10): loop = asyncio.get_event_loop() diff --git 
a/chatchat-server/chatchat/webui_pages/loom_view_client.py b/chatchat-server/chatchat/webui_pages/loom_view_client.py deleted file mode 100644 index 6c19b6de..00000000 --- a/chatchat-server/chatchat/webui_pages/loom_view_client.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Tuple, Any - -import streamlit as st -from loom_core.openai_plugins.publish import LoomOpenAIPluginsClient -import logging - -logger = logging.getLogger(__name__) -client = LoomOpenAIPluginsClient(base_url="http://localhost:8000", timeout=300, use_async=False) - - -def update_store(): - logger.info("update_status") - st.session_state.status = client.status() - logger.info("update_list_plugins") - list_plugins = client.list_plugins() - st.session_state.run_plugins_list = list_plugins.get("plugins_list", []) - - logger.info("update_launch_subscribe_info") - launch_subscribe_info = {} - for plugin_name in st.session_state.run_plugins_list: - launch_subscribe_info[plugin_name] = client.launch_subscribe_info(plugin_name) - - st.session_state.launch_subscribe_info = launch_subscribe_info - - logger.info("update_list_running_models") - list_running_models = {} - for plugin_name in st.session_state.run_plugins_list: - list_running_models[plugin_name] = client.list_running_models(plugin_name) - st.session_state.list_running_models = list_running_models - - logger.info("update_model_config") - model_config = {} - for plugin_name in st.session_state.run_plugins_list: - model_config[plugin_name] = client.list_llm_models(plugin_name) - st.session_state.model_config = model_config - - -def start_plugin(): - import time - start_plugins_name = st.session_state.plugins_name - if start_plugins_name in st.session_state.run_plugins_list: - st.toast(start_plugins_name + " has already been counted.") - - time.sleep(.5) - else: - - st.toast("start_plugin " + start_plugins_name + ",starting.") - result = client.launch_subscribe(start_plugins_name) - st.toast("start_plugin " + start_plugins_name + " ." + result.get("detail", "")) - time.sleep(3) - result1 = client.launch_subscribe_start(start_plugins_name) - - st.toast("start_plugin " + start_plugins_name + " ." + result1.get("detail", "")) - time.sleep(2) - update_store() - - -def start_worker(): - select_plugins_name = st.session_state.plugins_name - select_worker_id = st.session_state.worker_id - start_model_list = st.session_state.list_running_models.get(select_plugins_name, []) - already_counted = False - for model in start_model_list: - if model['worker_id'] == select_worker_id: - already_counted = True - break - - if already_counted: - st.toast( - "select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has already been counted.") - import time - time.sleep(.5) - else: - - st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " starting.") - result = client.launch_subscribe_start_model(select_plugins_name, select_worker_id) - st.toast("start worker_id " + select_worker_id + " ." 
+ result.get("detail", "")) - import time - time.sleep(.5) - update_store() - - -def stop_worker(): - select_plugins_name = st.session_state.plugins_name - select_worker_id = st.session_state.worker_id - start_model_list = st.session_state.list_running_models.get(select_plugins_name, []) - already_counted = False - for model in start_model_list: - if model['worker_id'] == select_worker_id: - already_counted = True - break - - if not already_counted: - st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " has bad already") - import time - time.sleep(.5) - else: - - st.toast("select_plugins_name " + select_plugins_name + ",worker_id " + select_worker_id + " stopping.") - result = client.launch_subscribe_stop_model(select_plugins_name, select_worker_id) - st.toast("stop worker_id " + select_worker_id + " ." + result.get("detail", "")) - import time - time.sleep(.5) - update_store() - - -def build_providers_model_plugins_name(): - import streamlit_antd_components as sac - if "run_plugins_list" not in st.session_state: - return [] - # 按照模型构建sac.menu(菜单 - menu_items = [] - for key, value in st.session_state.list_running_models.items(): - menu_item_children = [] - for model in value: - if "model" in model["providers"]: - menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"])) - - menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children)) - - return menu_items - - -def build_providers_embedding_plugins_name(): - import streamlit_antd_components as sac - if "run_plugins_list" not in st.session_state: - return [] - # 按照模型构建sac.menu(菜单 - menu_items = [] - for key, value in st.session_state.list_running_models.items(): - menu_item_children = [] - for model in value: - if "embedding" in model["providers"]: - menu_item_children.append(sac.MenuItem(model["model_name"], description=model["model_description"])) - - menu_items.append(sac.MenuItem(key, icon='box-fill', children=menu_item_children)) - - return menu_items - - -def find_menu_items_by_index(menu_items, key): - for menu_item in menu_items: - if menu_item.get('children') is not None: - for child in menu_item.get('children'): - if child.get('key') == key: - return menu_item, child - - return None, None - - -def set_llm_select(plugins_info, llm_model_worker): - st.session_state["select_plugins_info"] = plugins_info - st.session_state["select_model_worker"] = llm_model_worker - - -def get_select_model_endpoint() -> Tuple[str, str]: - plugins_info = st.session_state["select_plugins_info"] - llm_model_worker = st.session_state["select_model_worker"] - if plugins_info is None or llm_model_worker is None: - raise ValueError("plugins_info or llm_model_worker is None") - plugins_name = st.session_state["select_plugins_info"]['label'] - select_model_name = st.session_state["select_model_worker"]['label'] - adapter_description = st.session_state.launch_subscribe_info[plugins_name] - endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "") - return endpoint_host, select_model_name - - -def set_embed_select(plugins_info, embed_model_worker): - st.session_state["select_embed_plugins_info"] = plugins_info - st.session_state["select_embed_model_worker"] = embed_model_worker - - -def get_select_embed_endpoint() -> Tuple[str, str]: - select_embed_plugins_info = st.session_state["select_embed_plugins_info"] - select_embed_model_worker = st.session_state["select_embed_model_worker"] - if select_embed_plugins_info is None or 
select_embed_model_worker is None: - raise ValueError("select_embed_plugins_info or select_embed_model_worker is None") - embed_plugins_name = st.session_state["select_embed_plugins_info"]['label'] - select_embed_model_name = st.session_state["select_embed_model_worker"]['label'] - endpoint_host = None - if embed_plugins_name in st.session_state.launch_subscribe_info: - adapter_description = st.session_state.launch_subscribe_info[embed_plugins_name] - endpoint_host = adapter_description.get("adapter_description", {}).get("endpoint_host", "") - return endpoint_host, select_embed_model_name diff --git a/chatchat-server/chatchat/webui_pages/openai_plugins/__init__.py b/chatchat-server/chatchat/webui_pages/openai_plugins/__init__.py deleted file mode 100644 index 72a273ae..00000000 --- a/chatchat-server/chatchat/webui_pages/openai_plugins/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import openai_plugins_page diff --git a/chatchat-server/chatchat/webui_pages/openai_plugins/base.py b/chatchat-server/chatchat/webui_pages/openai_plugins/base.py deleted file mode 100644 index 3d600285..00000000 --- a/chatchat-server/chatchat/webui_pages/openai_plugins/base.py +++ /dev/null @@ -1,67 +0,0 @@ -import streamlit as st -from loom_openai_plugins_frontend import loom_openai_plugins_frontend - -from chatchat.webui_pages.utils import ApiRequest -from chatchat.webui_pages.loom_view_client import ( - update_store, - start_plugin, - start_worker, - stop_worker, -) - - -def openai_plugins_page(api: ApiRequest, is_lite: bool = None): - - - with (st.container()): - - if "worker_id" not in st.session_state: - st.session_state.worker_id = '' - if "plugins_name" not in st.session_state and "status" in st.session_state: - - for k, v in st.session_state.status.get("status", {}).get("subscribe_status", []).items(): - st.session_state.plugins_name = v.get("plugins_names", [])[0] - break - - col1, col2 = st.columns([0.8, 0.2]) - - with col1: - event = loom_openai_plugins_frontend(plugins_status=st.session_state.status, - run_list_plugins=st.session_state.run_plugins_list, - launch_subscribe_info=st.session_state.launch_subscribe_info, - list_running_models=st.session_state.list_running_models, - model_config=st.session_state.model_config) - - with col2: - st.write("操作") - if not st.session_state.run_plugins_list: - button_type_disabled = False - button_start_text = '启动' - else: - button_type_disabled = True - button_start_text = '已启动' - - if event: - event_type = event.get("event") - if event_type == "BottomNavigationAction": - st.session_state.plugins_name = event.get("data") - st.session_state.worker_id = '' - # 不存在run_plugins_list,打开启动按钮 - if st.session_state.plugins_name not in st.session_state.run_plugins_list \ - or st.session_state.run_plugins_list: - button_type_disabled = False - button_start_text = '启动' - else: - button_type_disabled = True - button_start_text = '已启动' - if event_type == "CardCoverComponent": - st.session_state.worker_id = event.get("data") - - st.button(button_start_text, disabled=button_type_disabled, key="start", - on_click=start_plugin) - - if st.session_state.worker_id != '': - st.button("启动" + st.session_state.worker_id, key="start_worker", - on_click=start_worker) - st.button("停止" + st.session_state.worker_id, key="stop_worker", - on_click=stop_worker) diff --git a/chatchat-server/pyproject.toml b/chatchat-server/pyproject.toml index dabe058c..76fa1a12 100644 --- a/chatchat-server/pyproject.toml +++ b/chatchat-server/pyproject.toml @@ -216,7 +216,7 @@ build-backend = 
"poetry.core.masonry.api" # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. -addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" +addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -svv" # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/chatchat-server/tests/unit_server/test_init_server.py b/chatchat-server/tests/unit_server/test_init_server.py new file mode 100644 index 00000000..f1fed4e1 --- /dev/null +++ b/chatchat-server/tests/unit_server/test_init_server.py @@ -0,0 +1,6 @@ +from chatchat.model_loaders.init_server import init_server + + +def test_init_server(): + + init_server() diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py new file mode 100644 index 00000000..99b6fd1a --- /dev/null +++ b/model-providers/model_providers/__main__.py @@ -0,0 +1,47 @@ +import argparse +import asyncio +import logging + +from model_providers import BootstrapWebBuilder +from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms + +logger = logging.getLogger(__name__) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-providers", + type=str, + default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml", + help="run model_providers servers", + dest="model_providers", + ) + args = parser.parse_args() + try: + logging_conf = get_config_dict( + "DEBUG", + get_log_file(log_path="logs", sub_dir=f"local_{get_timestamp_ms()}"), + 122, + 111, + ) + boot = ( + BootstrapWebBuilder() + .model_providers_cfg_path( + model_providers_cfg_path=args.model_providers + ) + .host(host="127.0.0.1") + .port(port=20000) + .build() + ) + boot.set_app_event(started_event=None) + boot.serve(logging_conf=logging_conf) + + + async def pool_join_thread(): + await boot.join() + + + asyncio.run(pool_join_thread()) + except SystemExit: + logger.info("SystemExit raised, exiting") + raise diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 398b0bc9..39d91570 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -45,6 +45,7 @@ from model_providers.core.bootstrap.openai_protocol import ( Role, UsageInfo, ) +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers.core.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, @@ -111,7 +112,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: def _create_template_from_message_type( - message_type: str, template: Union[str, list] + message_type: str, template: Union[str, list] ) -> PromptMessage: """Create a message prompt template from a message type and template string. @@ -170,7 +171,7 @@ def _create_template_from_message_type( def _convert_to_message( - message: MessageLikeRepresentation, + message: MessageLikeRepresentation, ) -> Union[PromptMessage]: """Instantiate a message from a variety of message formats. 
@@ -212,7 +213,7 @@ def _convert_to_message( async def _stream_openai_chat_completion( - response: Generator, + response: Generator, ) -> AsyncGenerator[str, None]: request_id, model = None, None for chunk in response: @@ -362,11 +363,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): started_event.set() async def workspaces_model_providers(self, request: Request): - provider_list = self.get_provider_list(model_type=request.get("model_type")) + + provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list( + model_type=request.get("model_type")) return ProviderListResponse(data=provider_list) async def workspaces_model_types(self, model_type: str, request: Request): - models_by_model_type = self.get_models_by_model_type(model_type=model_type) + models_by_model_type = ProvidersWrapper( + provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type) return ProviderModelTypeResponse(data=models_by_model_type) async def list_models(self, provider: str, request: Request): @@ -399,7 +403,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -409,7 +413,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -469,9 +473,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index f74c5dd8..4adf6f31 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -4,22 +4,11 @@ from typing import List, Optional from fastapi import Request -from model_providers.bootstrap_web.entities.model_provider_entities import ( - CustomConfigurationResponse, - CustomConfigurationStatus, - ModelResponse, - ProviderResponse, - ProviderWithModelsResponse, - SystemConfigurationResponse, -) from model_providers.core.bootstrap.openai_protocol import ( ChatCompletionRequest, EmbeddingsRequest, ) -from model_providers.core.entities.model_entities import ModelStatus -from model_providers.core.entities.provider_entities import ProviderType from model_providers.core.model_manager import ModelManager -from model_providers.core.model_runtime.entities.model_entities import ModelType class Bootstrap: @@ -43,137 +32,6 @@ class Bootstrap: def provider_manager(self, provider_manager: ModelManager): self._provider_manager = provider_manager - def get_provider_list( - self, model_type: Optional[str] = None - ) -> List[ProviderResponse]: - """ - get provider list. 
- - :param model_type: model type - :return: - """ - # 合并两个字典的键 - provider = set( - self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys() - ) - provider.update( - self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys() - ) - # Get all provider configurations of the current workspace - provider_configurations = ( - self.provider_manager.provider_manager.get_configurations(provider=provider) - ) - - provider_responses = [] - for provider_configuration in provider_configurations.values(): - if model_type: - model_type_entity = ModelType.value_of(model_type) - if ( - model_type_entity - not in provider_configuration.provider.supported_model_types - ): - continue - - provider_response = ProviderResponse( - provider=provider_configuration.provider.provider, - label=provider_configuration.provider.label, - description=provider_configuration.provider.description, - icon_small=provider_configuration.provider.icon_small, - icon_large=provider_configuration.provider.icon_large, - background=provider_configuration.provider.background, - help=provider_configuration.provider.help, - supported_model_types=provider_configuration.provider.supported_model_types, - configurate_methods=provider_configuration.provider.configurate_methods, - provider_credential_schema=provider_configuration.provider.provider_credential_schema, - model_credential_schema=provider_configuration.provider.model_credential_schema, - preferred_provider_type=ProviderType.value_of("custom"), - custom_configuration=CustomConfigurationResponse( - status=CustomConfigurationStatus.ACTIVE - if provider_configuration.is_custom_configuration_available() - else CustomConfigurationStatus.NO_CONFIGURE - ), - system_configuration=SystemConfigurationResponse(enabled=False), - ) - - provider_responses.append(provider_response) - - return provider_responses - - def get_models_by_model_type( - self, model_type: str - ) -> List[ProviderWithModelsResponse]: - """ - get models by model type. 
- - :param model_type: model type - :return: - """ - # 合并两个字典的键 - provider = set( - self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys() - ) - provider.update( - self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys() - ) - # Get all provider configurations of the current workspace - provider_configurations = ( - self.provider_manager.provider_manager.get_configurations(provider=provider) - ) - - # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) - - # Group models by provider - provider_models = {} - for model in models: - if model.provider.provider not in provider_models: - provider_models[model.provider.provider] = [] - - if model.deprecated: - continue - - provider_models[model.provider.provider].append(model) - - # convert to ProviderWithModelsResponse list - providers_with_models: list[ProviderWithModelsResponse] = [] - for provider, models in provider_models.items(): - if not models: - continue - - first_model = models[0] - - has_active_models = any( - [model.status == ModelStatus.ACTIVE for model in models] - ) - - providers_with_models.append( - ProviderWithModelsResponse( - provider=provider, - label=first_model.provider.label, - icon_small=first_model.provider.icon_small, - icon_large=first_model.provider.icon_large, - status=CustomConfigurationStatus.ACTIVE - if has_active_models - else CustomConfigurationStatus.NO_CONFIGURE, - models=[ - ModelResponse( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - ) - for model in models - ], - ) - ) - - return providers_with_models - @classmethod @abstractmethod def from_config(cls, cfg=None): diff --git a/model-providers/model_providers/core/bootstrap/providers_wapper.py b/model-providers/model_providers/core/bootstrap/providers_wapper.py new file mode 100644 index 00000000..8a2af953 --- /dev/null +++ b/model-providers/model_providers/core/bootstrap/providers_wapper.py @@ -0,0 +1,153 @@ +from typing import Optional, List + + +from model_providers.bootstrap_web.entities.model_provider_entities import ( + CustomConfigurationResponse, + CustomConfigurationStatus, + ModelResponse, + ProviderResponse, + ProviderWithModelsResponse, + SystemConfigurationResponse, +) + +from model_providers.core.entities.model_entities import ModelStatus +from model_providers.core.entities.provider_entities import ProviderType + +from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.provider_manager import ProviderManager + + +class ProvidersWrapper: + def __init__(self, provider_manager: ProviderManager): + self.provider_manager = provider_manager + + def get_provider_list( + self, model_type: Optional[str] = None + ) -> List[ProviderResponse]: + """ + get provider list. 
+ + :param model_type: model type + :return: + """ + # 合并两个字典的键 + provider = set( + self.provider_manager.provider_name_to_provider_records_dict.keys() + ) + provider.update( + self.provider_manager.provider_name_to_provider_model_records_dict.keys() + ) + # Get all provider configurations of the current workspace + provider_configurations = ( + self.provider_manager.get_configurations(provider=provider) + ) + + provider_responses = [] + for provider_configuration in provider_configurations.values(): + if model_type: + model_type_entity = ModelType.value_of(model_type) + if ( + model_type_entity + not in provider_configuration.provider.supported_model_types + ): + continue + + provider_response = ProviderResponse( + provider=provider_configuration.provider.provider, + label=provider_configuration.provider.label, + description=provider_configuration.provider.description, + icon_small=provider_configuration.provider.icon_small, + icon_large=provider_configuration.provider.icon_large, + background=provider_configuration.provider.background, + help=provider_configuration.provider.help, + supported_model_types=provider_configuration.provider.supported_model_types, + configurate_methods=provider_configuration.provider.configurate_methods, + provider_credential_schema=provider_configuration.provider.provider_credential_schema, + model_credential_schema=provider_configuration.provider.model_credential_schema, + preferred_provider_type=ProviderType.value_of("custom"), + custom_configuration=CustomConfigurationResponse( + status=CustomConfigurationStatus.ACTIVE + if provider_configuration.is_custom_configuration_available() + else CustomConfigurationStatus.NO_CONFIGURE + ), + system_configuration=SystemConfigurationResponse(enabled=False), + ) + + provider_responses.append(provider_response) + + return provider_responses + + def get_models_by_model_type( + self, model_type: str + ) -> List[ProviderWithModelsResponse]: + """ + get models by model type. 
+ + :param model_type: model type + :return: + """ + # 合并两个字典的键 + provider = set( + self.provider_manager.provider_name_to_provider_records_dict.keys() + ) + provider.update( + self.provider_manager.provider_name_to_provider_model_records_dict.keys() + ) + # Get all provider configurations of the current workspace + provider_configurations = ( + self.provider_manager.get_configurations(provider=provider) + ) + + # Get provider available models + models = provider_configurations.get_models( + model_type=ModelType.value_of(model_type) + ) + + # Group models by provider + provider_models = {} + for model in models: + if model.provider.provider not in provider_models: + provider_models[model.provider.provider] = [] + + if model.deprecated: + continue + + provider_models[model.provider.provider].append(model) + + # convert to ProviderWithModelsResponse list + providers_with_models: list[ProviderWithModelsResponse] = [] + for provider, models in provider_models.items(): + if not models: + continue + + first_model = models[0] + + has_active_models = any( + [model.status == ModelStatus.ACTIVE for model in models] + ) + + providers_with_models.append( + ProviderWithModelsResponse( + provider=provider, + label=first_model.provider.label, + icon_small=first_model.provider.icon_small, + icon_large=first_model.provider.icon_large, + status=CustomConfigurationStatus.ACTIVE + if has_active_models + else CustomConfigurationStatus.NO_CONFIGURE, + models=[ + ModelResponse( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + model_properties=model.model_properties, + status=model.status, + ) + for model in models + ], + ) + ) + + return providers_with_models
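Note (reference sketch, not part of the diff): init_provider_platforms() in chatchat/model_loaders/init_server.py converts each provider configured in model_providers.yaml into a MODEL_PLATFORMS entry that points at the OpenAI-compatible endpoint served by the model_providers process. The snippet below is a minimal illustration of the entry generated for the xinference provider and of a direct client call against that endpoint; it assumes the defaults shown in this diff (host 127.0.0.1, port 20000, model chatglm3-6b) — adjust to your own configuration.

# Sketch only, assuming the defaults from this diff: the platform entry that
# init_provider_platforms() would produce for the "xinference" provider, plus
# a direct call to the OpenAI-compatible endpoint exposed at
# http://127.0.0.1:20000/xinference/v1 by the model_providers server.
import openai

xinference_platform = {
    "platform_name": "xinference",
    "platform_type": "xinference",
    "api_base_url": "http://127.0.0.1:20000/xinference/v1",
    "api_key": "EMPTY",
    "api_concurrencies": 5,
    "llm_models": ["chatglm3-6b"],
    "embed_models": [],
    "image_models": [],
    "multimodal_models": [],
}

# The endpoint speaks the OpenAI protocol, so the stock client works against it.
client = openai.Client(
    base_url=xinference_platform["api_base_url"],
    api_key=xinference_platform["api_key"],
)
response = client.chat.completions.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "你好"}],
)
print(response.choices[0].message.content)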