/v1/models 接口返回值由 List[Model] 改为 {'data': List[Model]},兼容最新版 xinference

This commit is contained in:
liunux4odoo 2024-04-14 21:20:45 +08:00
parent d7dc01e9b1
commit f1e5b2c9aa

View File

@ -100,7 +100,7 @@ async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[],
@openai_router.get("/models") @openai_router.get("/models")
async def list_models() -> List: async def list_models() -> Dict:
''' '''
整合所有平台的模型列表 整合所有平台的模型列表
''' '''
@ -108,23 +108,17 @@ async def list_models() -> List:
try: try:
client = get_OpenAIClient(name, is_async=True) client = get_OpenAIClient(name, is_async=True)
models = await client.models.list() models = await client.models.list()
if config.get("platform_type") == "xinference": return [{**x.model_dump(), "platform_name": name} for x in models.data]
models = models.model_dump(exclude={"data":..., "object":...})
for x in models:
models[x]["platform_name"] = name
return [{**v, "id": k} for k, v in models.items()]
elif config.get("platform_type") == "oneapi":
return [{**x.model_dump(), "platform_name": name} for x in models.data]
except Exception: except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True) logger.error(f"failed request to platform: {name}", exc_info=True)
return {} return []
result = [] result = []
tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()] tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
for t in asyncio.as_completed(tasks): for t in asyncio.as_completed(tasks):
result += (await t) result += (await t)
return result return {"object": "list", "data": result}
@openai_router.post("/chat/completions") @openai_router.post("/chat/completions")