diff --git a/configs/model_config.py.example b/configs/model_config.py.example index b2256d3e..25763d67 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -105,7 +105,8 @@ llm_model_dict = { }, # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf "qianfan-api": { - "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" + "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo", 更多的见文档模型支持列表中千帆部分。 + "version_url": "", # 可以不填写version,直接填写在千帆申请模型发布的API地址 "api_base_url": "http://127.0.0.1:8888/v1", "api_key": "", "secret_key": "", diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index cbdd9965..8a593a7e 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -9,10 +9,39 @@ from server.utils import get_model_worker_config from typing import List, Literal, Dict -# TODO: support all qianfan models MODEL_VERSIONS = { "ernie-bot": "completions", "ernie-bot-turbo": "eb-instant", + "bloomz-7b": "bloomz_7b1", + "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed", + "llama2-7b-chat": "llama_2_7b", + "llama2-13b-chat": "llama_2_13b", + "llama2-70b-chat": "llama_2_70b", + "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b", + "chatglm2-6b-32k": "chatglm2_6b_32k", + "aquilachat-7b": "aquilachat_7b", + # "linly-llama2-ch-7b": "", # 暂未发布 + # "linly-llama2-ch-13b": "", # 暂未发布 + # "chatglm2-6b": "", # 暂未发布 + # "chatglm2-6b-int4": "", # 暂未发布 + # "falcon-7b": "", # 暂未发布 + # "falcon-180b-chat": "", # 暂未发布 + # "falcon-40b": "", # 暂未发布 + # "rwkv4-world": "", # 暂未发布 + # "rwkv5-world": "", # 暂未发布 + # "rwkv4-pile-14b": "", # 暂未发布 + # "rwkv4-raven-14b": "", # 暂未发布 + # "open-llama-7b": "", # 暂未发布 + # "dolly-12b": "", # 暂未发布 + # "mpt-7b-instruct": "", # 暂未发布 + # "mpt-30b-instruct": "", # 暂未发布 + # "OA-Pythia-12B-SFT-4": "", # 暂未发布 + # "xverse-13b": "", # 暂未发布 + + # # 以下为企业测试,需要单独申请 + # "flan-ul2": "", + # "Cerebras-GPT-6.7B": "" + # "Pythia-6.9B": "" } @@ -40,12 +69,13 @@ def request_qianfan_api( '/{model_version}?access_token={access_token}' config = get_model_worker_config(model_name) version = version or config.get("version") + version_url = config.get("version_url") access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key")) if not access_token: raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?") url = BASE_URL.format( - model_version=MODEL_VERSIONS[version], + model_version=version_url or MODEL_VERSIONS[version], access_token=access_token, ) payload = {