From 26672ffeda462858d5059aaf2779da6941e58d55 Mon Sep 17 00:00:00 2001 From: hxb Date: Sun, 28 Apr 2024 16:14:31 +0800 Subject: [PATCH] =?UTF-8?q?feature=EF=BC=9A=E9=95=BF=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=BE=AA=E7=8E=AF=E5=A4=84=E7=90=86texts=E6=97=B6=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=87=8D=E8=AF=95=E6=9E=81=E8=87=B4=EF=BC=8C?= =?UTF-8?q?=E9=99=8D=E4=BD=8E=E5=8D=95=E7=89=87=E6=96=87=E6=A1=A3=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=E5=AF=BC=E8=87=B4=E6=95=B4=E4=B8=AA=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E5=90=91=E9=87=8F=E5=8C=96=E5=A4=B1=E8=B4=A5=E7=9A=84=E6=A6=82?= =?UTF-8?q?=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/model_workers/zhipu.py | 40 +++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 1eb37e69..939f411f 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import httpx +import requests from fastchat.conversation import Conversation from httpx_sse import EventSource @@ -44,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker): def __init__( self, *, - model_names: List[str] = ["zhipu-api"], + model_names: List[str] = ("zhipu-api",), controller_addr: str = None, worker_addr: str = None, version: Literal["glm-4"] = "glm-4", @@ -87,28 +88,45 @@ class ChatGLMWorker(ApiModelWorker): def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL + params.load_config(self.model_names[0]) - token = generate_token(params.api_key, 60) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {token}" - } i = 0 batch_size = 1 result = [] while i < len(params.texts): + token = generate_token(params.api_key, 60) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}" + } data = { - "model": params.embed_model or self.DEFAULT_EMBED_MODEL, + "model": embed_model, "input": "".join(params.texts[i: i + batch_size]) } + embedding_data = self.request_embedding_api(headers, data, 1) + if embedding_data: + result.append(embedding_data) + i += batch_size + print(f"请求{embed_model}接口处理第{i}块文本,返回embeddings: \n{embedding_data}") + + return {"code": 200, "data": result} + + # 请求接口,支持重试 + def request_embedding_api(self, headers, data, retry=0): + response = '' + try: url = "https://open.bigmodel.cn/api/paas/v4/embeddings" response = requests.post(url, headers=headers, json=data) ans = response.json() - result.append(ans["data"][0]["embedding"]) - i += batch_size + return ans["data"][0]["embedding"] + except Exception as e: + print(f"request_embedding_api error={e} \nresponse={response}") + if retry > 0: + return self.request_embedding_api(headers, data, retry - 1) + else: + return None - return {"code": 200, "data": result} - def get_embeddings(self, params): print("embedding") print(params)