From c6d91c3e3def2a8b7632813f96ca99c9aad3ca5a Mon Sep 17 00:00:00 2001 From: qqlww1987 <2274782404@qq.com> Date: Wed, 6 Mar 2024 14:14:06 +0800 Subject: [PATCH] =?UTF-8?q?=E7=99=BE=E5=B7=9DAPI=E7=A4=BA=E4=BE=8B?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 百川API调用示例 --- server/model_workers/baichuan.py | 71 +++++++++++++++----------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py index 75cfad4e..4516e11a 100644 --- a/server/model_workers/baichuan.py +++ b/server/model_workers/baichuan.py @@ -1,7 +1,7 @@ import json import time import hashlib - +import requests from fastchat.conversation import Conversation from server.model_workers.base import * from server.utils import get_httpx_client @@ -32,61 +32,58 @@ class BaiChuanWorker(ApiModelWorker): kwargs.setdefault("context_len", 32768) super().__init__(**kwargs) self.version = version - def do_chat(self, params: ApiChatParams) -> Dict: params.load_config(self.model_names[0]) - url = "https://api.baichuan-ai.com/v1/stream/chat" + url = "https://api.baichuan-ai.com/v1/chat/completions" data = { "model": params.version, "messages": params.messages, - "parameters": {"temperature": params.temperature} + "stream": False, + } - json_data = json.dumps(data) - time_stamp = int(time.time()) - signature = calculate_md5(params.secret_key + json_data + str(time_stamp)) headers = { "Content-Type": "application/json", "Authorization": "Bearer " + params.api_key, - "X-BC-Request-Id": "your requestId", - "X-BC-Timestamp": str(time_stamp), - "X-BC-Signature": signature, - "X-BC-Sign-Algo": "MD5", + } - text = "" - if log_verbose: - logger.info(f'{self.__class__.__name__}:json_data: {json_data}') - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - with client.stream("POST", url, headers=headers, json=data) as response: - for line in response.iter_lines(): - if not line.strip(): - continue - resp = json.loads(line) - if resp["code"] == 0: - text += resp["data"]["messages"][-1]["content"] - yield { - "error_code": resp["code"], - "text": text + response = requests.post(url, headers=headers, json=data) + if response.status_code == 200: + print("请求成功!"+response.text) + result = json.loads(response.text) + textMsg="" + result["choices"][0]["delta"]=result["choices"][0]["message"] + if 'choices' in result: + textMsg += result["choices"][0]["message"]["content"] + data = { + "error_code": response.status_code, + "text": textMsg, + "choices":result["choices"], + "model":result["model"], + "object":result["object"], + "object":result["object"], + "created":result["created"], + "id":result["id"], } - else: - data = { - "error_code": resp["code"], - "text": resp["msg"], + + yield data + + else: + + data = { + "error_code": response.status_code, + "text":response.text, "error": { - "message": resp["msg"], + "message": response.text, "type": "invalid_request_error", "param": None, "code": None, } - } - self.logger.error(f"请求百川 API 时发生错误:{data}") - yield data - + } + self.logger.error(f"请求百川 API 时发生错误:{data}") + yield data def get_embeddings(self, params): print("embedding") print(params)