From 6011bac07fa24c612a6522e0ee6810daecf00ebe Mon Sep 17 00:00:00 2001 From: dengpeng777 Date: Sat, 30 Mar 2024 09:58:18 +0800 Subject: [PATCH 1/2] support zhipu-api embedding --- server/model_workers/zhipu.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 898427c8..b925760d 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -84,7 +84,29 @@ class ChatGLMWorker(ApiModelWorker): # yield {"error_code": 0, "text": text} + def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: + params.load_config(self.model_names[0]) + token = generate_token(params.api_key, 60) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token}" + } + i = 0 + batch_size = 1 + result = [] + while i < len(params.texts): + data = { + "model": params.embed_model or self.DEFAULT_EMBED_MODEL, + "input": "".join(params.texts[i: i + batch_size]) + } + url = "https://open.bigmodel.cn/api/paas/v4/embeddings" + response = requests.post(url, headers=headers, json=data) + ans = response.json() + result.append(ans["data"][0]["embedding"]) + i += batch_size + return {"code": 200, "data": result} + def get_embeddings(self, params): print("embedding") print(params) From a592b1cd3fd8027b392231f09b4e1d91b3cb0346 Mon Sep 17 00:00:00 2001 From: dengpeng777 Date: Sat, 30 Mar 2024 10:00:55 +0800 Subject: [PATCH 2/2] fix add default embed model --- server/model_workers/zhipu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index b925760d..1eb37e69 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -39,6 +39,8 @@ def generate_token(apikey: str, exp_seconds: int): class ChatGLMWorker(ApiModelWorker): + DEFAULT_EMBED_MODEL = "embedding-2" + def __init__( self, *,