diff --git a/chatchat-server/chatchat/server/api_server/openai_routes.py b/chatchat-server/chatchat/server/api_server/openai_routes.py
index c9e616e2..bde5b04c 100644
--- a/chatchat-server/chatchat/server/api_server/openai_routes.py
+++ b/chatchat-server/chatchat/server/api_server/openai_routes.py
@@ -100,7 +100,7 @@ async def openai_request(method, body, extra_json: Dict={}, header: Iterable=[],
 
 
 @openai_router.get("/models")
-async def list_models() -> List:
+async def list_models() -> Dict:
     '''
     整合所有平台的模型列表。
     '''
@@ -108,23 +108,17 @@ async def list_models() -> List:
         try:
             client = get_OpenAIClient(name, is_async=True)
             models = await client.models.list()
-            if config.get("platform_type") == "xinference":
-                models = models.model_dump(exclude={"data":..., "object":...})
-                for x in models:
-                    models[x]["platform_name"] = name
-                return [{**v, "id": k} for k, v in models.items()]
-            elif config.get("platform_type") == "oneapi":
-                return [{**x.model_dump(), "platform_name": name} for x in models.data]
+            return [{**x.model_dump(), "platform_name": name} for x in models.data]
         except Exception:
             logger.error(f"failed request to platform: {name}", exc_info=True)
-            return {}
+            return []
 
     result = []
     tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
     for t in asyncio.as_completed(tasks):
         result += (await t)
 
-    return result
+    return {"object": "list", "data": result}
 
 
 @openai_router.post("/chat/completions")
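
After this change, `list_models` (docstring: "aggregate the model lists from all platforms") no longer branches on `platform_type` and returns an OpenAI-style envelope `{"object": "list", "data": [...]}` instead of a bare list, with each entry carrying a `platform_name` field; failed platforms contribute an empty list rather than an empty dict. Below is a minimal consumption sketch, not part of this diff: the base URL and API key are assumptions for illustration and should be adjusted to the actual deployment.

```python
# Minimal sketch of a client consuming the reshaped /models endpoint.
# base_url and api_key are illustrative assumptions, not values from this diff.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:7861/v1",  # assumed local chatchat-server address
    api_key="EMPTY",                       # assumed placeholder key
)

# With the {"object": "list", "data": [...]} envelope, the official openai
# client can parse the response as a model page instead of a bare JSON array.
models = client.models.list()
for m in models.data:
    # "platform_name" is the extra per-model field added by list_models();
    # it is preserved as an extra attribute on the SDK's Model object.
    print(m.id, m.model_dump().get("platform_name"))
```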