From f0f1dc2537f4c5c56f62c1593304580b3c4b48dc Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Fri, 15 Sep 2023 00:30:18 +0800
Subject: [PATCH] Add more models supported by the Qianfan platform; in
 addition to specifying a model name, support specifying the model API URL
 directly, which makes it easy to fill in the address of a separately
 requested model deployment

---
 configs/model_config.py.example |  3 ++-
 server/model_workers/qianfan.py | 34 +++++++++++++++++++++++++++++++--
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index b2256d3e..25763d67 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -105,7 +105,8 @@ llm_model_dict = {
     },
     # Baidu Qianfan API; for how to apply, see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
     "qianfan-api": {
-        "version": "ernie-bot",  # currently supports "ernie-bot" or "ernie-bot-turbo"
+        "version": "ernie-bot",  # currently supports "ernie-bot" or "ernie-bot-turbo"; for more, see the Qianfan section of the supported-model list in the docs
+        "version_url": "",  # version may be left unset: fill in the API URL of the model deployment you requested on Qianfan instead
         "api_base_url": "http://127.0.0.1:8888/v1",
         "api_key": "",
         "secret_key": "",
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index cbdd9965..8a593a7e 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -9,10 +9,39 @@ from server.utils import get_model_worker_config
 from typing import List, Literal, Dict
 
 
-# TODO: support all qianfan models
 MODEL_VERSIONS = {
     "ernie-bot": "completions",
     "ernie-bot-turbo": "eb-instant",
+    "bloomz-7b": "bloomz_7b1",
+    "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
+    "llama2-7b-chat": "llama_2_7b",
+    "llama2-13b-chat": "llama_2_13b",
+    "llama2-70b-chat": "llama_2_70b",
+    "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
+    "chatglm2-6b-32k": "chatglm2_6b_32k",
+    "aquilachat-7b": "aquilachat_7b",
+    # "linly-llama2-ch-7b": "",  # not yet released
+    # "linly-llama2-ch-13b": "",  # not yet released
+    # "chatglm2-6b": "",  # not yet released
+    # "chatglm2-6b-int4": "",  # not yet released
+    # "falcon-7b": "",  # not yet released
+    # "falcon-180b-chat": "",  # not yet released
+    # "falcon-40b": "",  # not yet released
+    # "rwkv4-world": "",  # not yet released
+    # "rwkv5-world": "",  # not yet released
+    # "rwkv4-pile-14b": "",  # not yet released
+    # "rwkv4-raven-14b": "",  # not yet released
+    # "open-llama-7b": "",  # not yet released
+    # "dolly-12b": "",  # not yet released
+    # "mpt-7b-instruct": "",  # not yet released
+    # "mpt-30b-instruct": "",  # not yet released
+    # "OA-Pythia-12B-SFT-4": "",  # not yet released
+    # "xverse-13b": "",  # not yet released
+
+    # # The following are enterprise trials and require a separate application
+    # "flan-ul2": "",
+    # "Cerebras-GPT-6.7B": "",
+    # "Pythia-6.9B": "",
 }
 
 
@@ -40,12 +69,13 @@ def request_qianfan_api(
                '/{model_version}?access_token={access_token}'
     config = get_model_worker_config(model_name)
     version = version or config.get("version")
+    version_url = config.get("version_url")
     access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
     if not access_token:
         raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
 
     url = BASE_URL.format(
-        model_version=MODEL_VERSIONS[version],
+        model_version=version_url or MODEL_VERSIONS[version],
         access_token=access_token,
     )
     payload = {
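
For reference, the sketch below illustrates the URL resolution this patch introduces: request_qianfan_api now reads version_url from the worker config, and because Python's "or" short-circuits, a non-empty version_url is substituted into the {model_version} slot of the base URL while an empty one falls back to the MODEL_VERSIONS lookup. This is an illustrative sketch, not the patched module itself: resolve_url, the sample config dicts, and the TOKEN placeholder are hypothetical, and the URL prefix (elided from the hunk above) is taken from Baidu's public Qianfan documentation.

# Illustrative sketch (not part of the patch): how the patched worker picks
# the {model_version} slot of the Qianfan chat URL. resolve_url, the sample
# config dicts, and the TOKEN value below are hypothetical stand-ins.

MODEL_VERSIONS = {
    "ernie-bot": "completions",
    "ernie-bot-turbo": "eb-instant",
}

BASE_URL = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'
            '/{model_version}?access_token={access_token}')


def resolve_url(config: dict, access_token: str, version: str = None) -> str:
    version = version or config.get("version")
    version_url = config.get("version_url")
    # "or" short-circuits: a non-empty version_url wins and MODEL_VERSIONS is
    # never consulted; an empty one falls back to the named-version lookup.
    return BASE_URL.format(
        model_version=version_url or MODEL_VERSIONS[version],
        access_token=access_token,
    )


# Named version -> .../chat/completions?access_token=TOKEN
print(resolve_url({"version": "ernie-bot", "version_url": ""}, "TOKEN"))
# Privately deployed endpoint -> .../chat/my_private_endpoint?access_token=TOKEN
print(resolve_url({"version_url": "my_private_endpoint"}, "TOKEN"))

Because version_url defaults to "" in model_config.py.example, existing configs that only set version resolve exactly as before, while a model deployment requested individually on Qianfan needs no MODEL_VERSIONS entry at all.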