From 1dc069fa9c38e23ffd4e08f278f6bfeb5e32274f Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Thu, 7 Mar 2024 08:31:47 +0800
Subject: [PATCH] =?UTF-8?q?model=5Fconfig=20=E4=B8=AD=E8=A1=A5=E5=85=85=20?=
 =?UTF-8?q?oneapi=20=E9=BB=98=E8=AE=A4=E5=9C=A8=E7=BA=BF=E6=A8=A1=E5=9E=8B?=
 =?UTF-8?q?=EF=BC=9B/v1/models=20=E6=8E=A5=E5=8F=A3=E6=94=AF=E6=8C=81=20on?=
 =?UTF-8?q?eapi=20=E5=B9=B3=E5=8F=B0=EF=BC=8C=E7=BB=9F=E4=B8=80=E8=BF=94?=
 =?UTF-8?q?=E5=9B=9E=E6=A8=A1=E5=9E=8B=E5=88=97=E8=A1=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/model_config.py.example    | 49 ++++++++++++++++++++----------
 server/api_server/openai_routes.py | 25 ++++++++-------
 2 files changed, 45 insertions(+), 29 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index bc523e33..89ad80cc 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -118,22 +118,39 @@ MODEL_PLATFORMS = [
         ],
     },
 
-    # {
-    #     "platform_name": "oneapi",
-    #     "platform_type": "oneapi",
-    #     "api_base_url": "http://127.0.0.1:3000/v1",
-    #     "api_key": "",
-    #     "api_concurrencies": 5,
-    #     "llm_models": [
-    #         "qwen-turbo",
-    #         "qwen-plus",
-    #         "chatglm_turbo",
-    #         "chatglm_std",
-    #     ],
-    #     "embed_models": [],
-    #     "image_models": [],
-    #     "multimodal_models": [],
-    # },
+    {
+        "platform_name": "oneapi",
+        "platform_type": "oneapi",
+        "api_base_url": "http://127.0.0.1:3000/v1",
+        "api_key": "sk-",
+        "api_concurrencies": 5,
+        "llm_models": [
+            # 智谱 API
+            "chatglm_pro",
+            "chatglm_turbo",
+            "chatglm_std",
+            "chatglm_lite",
+            # 千问 API
+            "qwen-turbo",
+            "qwen-plus",
+            "qwen-max",
+            "qwen-max-longcontext",
+            # 千帆 API
+            "ERNIE-Bot",
+            "ERNIE-Bot-turbo",
+            "ERNIE-Bot-4",
+            # 星火 API
+            "SparkDesk",
+        ],
+        "embed_models": [
+            # 千问 API
+            "text-embedding-v1",
+            # 千帆 API
+            "Embedding-V1",
+        ],
+        "image_models": [],
+        "multimodal_models": [],
+    },
 
     # {
     #     "platform_name": "loom",
diff --git a/server/api_server/openai_routes.py b/server/api_server/openai_routes.py
index beb619b6..69d663cb 100644
--- a/server/api_server/openai_routes.py
+++ b/server/api_server/openai_routes.py
@@ -65,30 +65,29 @@ async def openai_request(method, body):
 
 
 @openai_router.get("/models")
-async def list_models() -> Dict:
+async def list_models() -> List:
     '''
     整合所有平台的模型列表。
-    由于 openai sdk 不支持重名模型,对于重名模型,只返回其中响应速度最快的一个。在请求其它接口时会自动按照模型忙闲状态进行调度。
     '''
-    async def task(name: str):
+    async def task(name: str, config: Dict):
         try:
             client = get_OpenAIClient(name, is_async=True)
             models = await client.models.list()
-            models = models.dict(exclude={"data":..., "object":...})
-            for x in models:
-                models[x]["platform_name"] = name
-            return models
+            if config.get("platform_type") == "xinference":
+                models = models.dict(exclude={"data":..., "object":...})
+                for x in models:
+                    models[x]["platform_name"] = name
+                return [{**v, "id": k} for k, v in models.items()]
+            elif config.get("platform_type") == "oneapi":
+                return [{**x.dict(), "platform_name": name} for x in models.data]
         except Exception:
             logger.error(f"failed request to platform: {name}", exc_info=True)
             return {}
 
-    result = {}
-    tasks = [asyncio.create_task(task(name)) for name in get_config_platforms()]
+    result = []
+    tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
     for t in asyncio.as_completed(tasks):
-        for n, v in (await t).items():
-            if n not in result:
-                result[n] = v
-
+        result += (await t)
     return result
 
 