From 0078cdc72451467c7b1af60cbb948e257059db52 Mon Sep 17 00:00:00 2001 From: hxb Date: Sun, 28 Apr 2024 16:12:23 +0800 Subject: [PATCH 1/2] =?UTF-8?q?bugfix:=20=E4=BD=BF=E7=94=A8=E5=90=91?= =?UTF-8?q?=E9=87=8F=E8=AE=A1=E7=AE=97=E6=96=B9=E5=BC=8FMETRIC=5FINNER=5FP?= =?UTF-8?q?RODUCT=E6=97=B6=E5=90=AF=E7=94=A8normalize=5FL2=E4=BC=9A?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E5=90=91=E9=87=8F=E5=8C=96=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_cache/faiss_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 60c550ee..bee88fbc 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -57,7 +57,7 @@ class _FaissPool(CachePool): ) -> FAISS: embeddings = EmbeddingsFunAdapter(embed_model) doc = Document(page_content="init", metadata={}) - vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") + vector_store = FAISS.from_documents([doc], embeddings, distance_strategy="METRIC_INNER_PRODUCT") ids = list(vector_store.docstore._dict.keys()) vector_store.delete(ids) return vector_store @@ -94,7 +94,7 @@ class KBFaissPool(_FaissPool): if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model) - vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") + vector_store = FAISS.load_local(vs_path, embeddings, distance_strategy="METRIC_INNER_PRODUCT") elif create: # create an empty vector store if not os.path.exists(vs_path): From 26672ffeda462858d5059aaf2779da6941e58d55 Mon Sep 17 00:00:00 2001 From: hxb Date: Sun, 28 Apr 2024 16:14:31 +0800 Subject: [PATCH 2/2] 
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
    """Embed ``params.texts`` chunk-by-chunk via the ZhipuAI embedding API.

    The token is re-generated per request so that long documents do not
    outlive the 60-second token TTL, and each chunk is retried once via
    ``request_embedding_api`` before giving up.

    Returns:
        ``{"code": 200, "data": [embedding, ...]}`` on success, with one
        embedding per input text, in order.
        ``{"code": 500, "msg": ...}`` if any chunk still fails after the
        retries. We fail loudly instead of skipping the chunk: silently
        dropping a failed chunk would return a shorter embedding list and
        misalign every subsequent embedding with its source document.
    """
    embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
    params.load_config(self.model_names[0])
    batch_size = 1  # API is called one chunk at a time
    result = []
    # range() guarantees the index always advances, even on a failed
    # chunk — no infinite-loop hazard from a conditional increment.
    for i in range(0, len(params.texts), batch_size):
        token = generate_token(params.api_key, 60)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        }
        data = {
            "model": embed_model,
            "input": "".join(params.texts[i: i + batch_size]),
        }
        embedding_data = self.request_embedding_api(headers, data, retry=1)
        if embedding_data is None:
            # Failed even after retrying: report instead of silently
            # returning a misaligned, shorter embedding list.
            return {
                "code": 500,
                "msg": f"request {embed_model} embeddings failed for chunk {i} after retries",
            }
        result.append(embedding_data)
    return {"code": 200, "data": result}


def request_embedding_api(self, headers, data, retry=0):
    """POST one embedding request, retrying up to ``retry`` extra times.

    Args:
        headers: HTTP headers including the short-lived Bearer token.
        data: JSON body with ``model`` and ``input`` keys.
        retry: number of additional attempts after the first failure.

    Returns:
        The embedding vector (``ans["data"][0]["embedding"]``) on success,
        or ``None`` once every attempt has failed.
    """
    url = "https://open.bigmodel.cn/api/paas/v4/embeddings"
    response = None  # stays None if the POST itself raises
    try:
        response = requests.post(url, headers=headers, json=data)
        ans = response.json()
        return ans["data"][0]["embedding"]
    except Exception as e:
        # response may be None (connection error) or a Response whose
        # body lacked the expected keys (API error payload).
        print(f"request_embedding_api error={e} \nresponse={response}")
        if retry > 0:
            return self.request_embedding_api(headers, data, retry - 1)
        return None


def get_embeddings(self, params):
    # Stub/debug hook: only logs the incoming params.
    print("embedding")
    print(params)