/v1/models 接口返回值由 List[Model] 改为 {'data': List[Model]},兼容最新版 xinference

This commit is contained in:
liunux4odoo 2024-04-14 21:20:45 +08:00
parent d7dc01e9b1
commit f1e5b2c9aa

View File

@@ -100,7 +100,7 @@ async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[],
@openai_router.get("/models")
async def list_models() -> List:
async def list_models() -> Dict:
'''
整合所有平台的模型列表
'''
@@ -108,23 +108,17 @@ async def list_models() -> List:
try:
client = get_OpenAIClient(name, is_async=True)
models = await client.models.list()
if config.get("platform_type") == "xinference":
models = models.model_dump(exclude={"data":..., "object":...})
for x in models:
models[x]["platform_name"] = name
return [{**v, "id": k} for k, v in models.items()]
elif config.get("platform_type") == "oneapi":
return [{**x.model_dump(), "platform_name": name} for x in models.data]
return [{**x.model_dump(), "platform_name": name} for x in models.data]
except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True)
return {}
return []
result = []
tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
for t in asyncio.as_completed(tasks):
result += (await t)
return result
return {"object": "list", "data": result}
@openai_router.post("/chat/completions")