diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 1eb37e69..939f411f 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import httpx +import requests from fastchat.conversation import Conversation from httpx_sse import EventSource @@ -44,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker): def __init__( self, *, - model_names: List[str] = ["zhipu-api"], + model_names: List[str] = ("zhipu-api",), controller_addr: str = None, worker_addr: str = None, version: Literal["glm-4"] = "glm-4", @@ -87,28 +88,45 @@ class ChatGLMWorker(ApiModelWorker): def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL + params.load_config(self.model_names[0]) - token = generate_token(params.api_key, 60) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {token}" - } i = 0 batch_size = 1 result = [] while i < len(params.texts): + token = generate_token(params.api_key, 60) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}" + } data = { - "model": params.embed_model or self.DEFAULT_EMBED_MODEL, + "model": embed_model, "input": "".join(params.texts[i: i + batch_size]) } + embedding_data = self.request_embedding_api(headers, data, 1) + if embedding_data: + result.append(embedding_data) + i += batch_size + print(f"请求{embed_model}接口处理第{i}块文本,返回embeddings: \n{embedding_data}") + + return {"code": 200, "data": result} + + # 请求接口,支持重试 + def request_embedding_api(self, headers, data, retry=0): + response = '' + try: url = "https://open.bigmodel.cn/api/paas/v4/embeddings" response = requests.post(url, headers=headers, json=data) ans = response.json() - result.append(ans["data"][0]["embedding"]) - i += batch_size + return ans["data"][0]["embedding"] + except Exception as e: + print(f"request_embedding_api error={e} \nresponse={response}") + if retry > 0: + return self.request_embedding_api(headers, data, retry - 1) + else: + return None - return {"code": 200, "data": result} - def get_embeddings(self, params): print("embedding") print(params)