model_config: add default oneapi online models; the /v1/models endpoint now supports the oneapi platform and returns a unified model list

Author: liunux4odoo
Date: 2024-03-07 08:31:47 +08:00
Parent: 82dfcd97e6
Commit: 1dc069fa9c

2 changed files with 45 additions and 29 deletions


@@ -118,22 +118,39 @@ MODEL_PLATFORMS = [
         ],
     },
-    # {
-    #     "platform_name": "oneapi",
-    #     "platform_type": "oneapi",
-    #     "api_base_url": "http://127.0.0.1:3000/v1",
-    #     "api_key": "",
-    #     "api_concurrencies": 5,
-    #     "llm_models": [
-    #         "qwen-turbo",
-    #         "qwen-plus",
-    #         "chatglm_turbo",
-    #         "chatglm_std",
-    #     ],
-    #     "embed_models": [],
-    #     "image_models": [],
-    #     "multimodal_models": [],
-    # },
+    {
+        "platform_name": "oneapi",
+        "platform_type": "oneapi",
+        "api_base_url": "http://127.0.0.1:3000/v1",
+        "api_key": "sk-",
+        "api_concurrencies": 5,
+        "llm_models": [
+            # Zhipu API
+            "chatglm_pro",
+            "chatglm_turbo",
+            "chatglm_std",
+            "chatglm_lite",
+            # Qwen API
+            "qwen-turbo",
+            "qwen-plus",
+            "qwen-max",
+            "qwen-max-longcontext",
+            # Qianfan API
+            "ERNIE-Bot",
+            "ERNIE-Bot-turbo",
+            "ERNIE-Bot-4",
+            # Spark API
+            "SparkDesk",
+        ],
+        "embed_models": [
+            # Qwen API
+            "text-embedding-v1",
+            # Qianfan API
+            "Embedding-V1",
+        ],
+        "image_models": [],
+        "multimodal_models": [],
+    },
     # {
     #     "platform_name": "loom",


@@ -65,30 +65,29 @@ async def openai_request(method, body):
 @openai_router.get("/models")
-async def list_models() -> Dict:
+async def list_models() -> List:
     '''
     Aggregate the model lists from all platforms.
-    Since the openai sdk does not support duplicate model names, only the copy from the fastest-responding platform is returned for duplicates; requests to other endpoints are scheduled automatically according to each model's busy/idle state.
     '''
-    async def task(name: str):
+    async def task(name: str, config: Dict):
         try:
             client = get_OpenAIClient(name, is_async=True)
             models = await client.models.list()
-            models = models.dict(exclude={"data":..., "object":...})
-            for x in models:
-                models[x]["platform_name"] = name
-            return models
+            if config.get("platform_type") == "xinference":
+                models = models.dict(exclude={"data":..., "object":...})
+                for x in models:
+                    models[x]["platform_name"] = name
+                return [{**v, "id": k} for k, v in models.items()]
+            elif config.get("platform_type") == "oneapi":
+                return [{**x.dict(), "platform_name": name} for x in models.data]
         except Exception:
             logger.error(f"failed request to platform: {name}", exc_info=True)
             return {}
-    result = {}
-    tasks = [asyncio.create_task(task(name)) for name in get_config_platforms()]
+    result = []
+    tasks = [asyncio.create_task(task(name, config)) for name, config in get_config_platforms().items()]
     for t in asyncio.as_completed(tasks):
-        for n, v in (await t).items():
-            if n not in result:
-                result[n] = v
+        result += (await t)
     return result
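
The new handler's structure can be summarized with a self-contained sketch of the same aggregation pattern: one task per configured platform, results gathered with asyncio.as_completed and concatenated into a single list, with per-platform failures logged and skipped. The fetch_platform_models helper below is a stand-in for the real client.models.list() call, not the project's code.

# Self-contained sketch of the aggregation pattern used by the new /models handler.
import asyncio
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

async def fetch_platform_models(name: str, config: Dict) -> List[Dict]:
    # Stand-in for get_OpenAIClient(name).models.list(); returns entries in the
    # unified shape the new endpoint produces: {"id": ..., "platform_name": ...}.
    await asyncio.sleep(0.01)
    return [{"id": m, "platform_name": name} for m in config.get("llm_models", [])]

async def list_models(platforms: Dict[str, Dict]) -> List[Dict]:
    async def task(name: str, config: Dict) -> List[Dict]:
        try:
            return await fetch_platform_models(name, config)
        except Exception:
            # One unreachable platform should not break the whole endpoint.
            logger.error(f"failed request to platform: {name}", exc_info=True)
            return []

    result: List[Dict] = []
    tasks = [asyncio.create_task(task(n, c)) for n, c in platforms.items()]
    for t in asyncio.as_completed(tasks):
        result += await t  # faster platforms contribute first; order is not fixed
    return result

if __name__ == "__main__":
    demo = {
        "oneapi": {"llm_models": ["qwen-turbo", "chatglm_turbo"]},
        "xinference": {"llm_models": ["chatglm3-6b"]},
    }
    print(asyncio.run(list_models(demo)))

Each returned entry carries a platform_name field alongside the usual model fields, so clients can tell which backend serves a given id even when the same model name is exposed by several platforms.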