Merge pull request #3744 from liunux4odoo/fix

/v1/models 接口返回值由 List[Model] 改为 {'data': List[Model]},兼容最新版 xinference
This commit is contained in:
liunux4odoo 2024-04-14 21:21:53 +08:00 committed by GitHub
commit 4ce7ce0709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -100,7 +100,7 @@ async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[],
@openai_router.get("/models") @openai_router.get("/models")
async def list_models() -> List: async def list_models() -> Dict:
''' '''
整合所有平台的模型列表 整合所有平台的模型列表
''' '''
@ -108,23 +108,17 @@ async def list_models() -> List:
try: try:
client = get_OpenAIClient(name, is_async=True) client = get_OpenAIClient(name, is_async=True)
models = await client.models.list() models = await client.models.list()
if config.get("platform_type") == "xinference": return [{**x.model_dump(), "platform_name": name} for x in models.data]
models = models.model_dump(exclude={"data":..., "object":...})
for x in models:
models[x]["platform_name"] = name
return [{**v, "id": k} for k, v in models.items()]
elif config.get("platform_type") == "oneapi":
return [{**x.model_dump(), "platform_name": name} for x in models.data]
except Exception: except Exception:
logger.error(f"failed request to platform: {name}", exc_info=True) logger.error(f"failed request to platform: {name}", exc_info=True)
return {} return []
result = [] result = []
tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()] tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
for t in asyncio.as_completed(tasks): for t in asyncio.as_completed(tasks):
result += (await t) result += (await t)
return result return {"object": "list", "data": result}
@openai_router.post("/chat/completions") @openai_router.post("/chat/completions")