diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index 60c550ee..bee88fbc 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -57,7 +57,7 @@ class _FaissPool(CachePool):
     ) -> FAISS:
         embeddings = EmbeddingsFunAdapter(embed_model)
         doc = Document(page_content="init", metadata={})
-        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
+        vector_store = FAISS.from_documents([doc], embeddings, distance_strategy="METRIC_INNER_PRODUCT")
         ids = list(vector_store.docstore._dict.keys())
         vector_store.delete(ids)
         return vector_store
@@ -94,7 +94,7 @@ class KBFaissPool(_FaissPool):
                 if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                     embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device,
                                                          default_embed_model=embed_model)
-                    vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
+                    vector_store = FAISS.load_local(vs_path, embeddings, distance_strategy="METRIC_INNER_PRODUCT")
                 elif create:
                     # create an empty vector store
                     if not os.path.exists(vs_path):
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 1eb37e69..939f411f 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -1,6 +1,7 @@
 from contextlib import contextmanager
 
 import httpx
+import requests
 from fastchat.conversation import Conversation
 from httpx_sse import EventSource
 
@@ -44,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker):
     def __init__(
             self,
             *,
-            model_names: List[str] = ["zhipu-api"],
+            model_names: List[str] = ("zhipu-api",),
             controller_addr: str = None,
             worker_addr: str = None,
             version: Literal["glm-4"] = "glm-4",
@@ -87,28 +88,45 @@ class ChatGLMWorker(ApiModelWorker):
 
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
+        embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
+
         params.load_config(self.model_names[0])
-        token = generate_token(params.api_key, 60)
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {token}"
-        }
         i = 0
         batch_size = 1
         result = []
         while i < len(params.texts):
+            token = generate_token(params.api_key, 60)
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {token}"
+            }
             data = {
-                "model": params.embed_model or self.DEFAULT_EMBED_MODEL,
+                "model": embed_model,
                 "input": "".join(params.texts[i: i + batch_size])
             }
+            embedding_data = self.request_embedding_api(headers, data, 1)
+            if embedding_data:
+                result.append(embedding_data)
+            i += batch_size
+            print(f"Called the {embed_model} API for text chunk {i}, returned embeddings: \n{embedding_data}")
+
+        return {"code": 200, "data": result}
+
+    # Call the embeddings API, with retry support
+    def request_embedding_api(self, headers, data, retry=0):
+        response = ''
+        try:
             url = "https://open.bigmodel.cn/api/paas/v4/embeddings"
             response = requests.post(url, headers=headers, json=data)
             ans = response.json()
-            result.append(ans["data"][0]["embedding"])
-            i += batch_size
+            return ans["data"][0]["embedding"]
+        except Exception as e:
+            print(f"request_embedding_api error={e} \nresponse={response}")
+            if retry > 0:
+                return self.request_embedding_api(headers, data, retry - 1)
+            else:
+                return None
 
-        return {"code": 200, "data": result}
-
     def get_embeddings(self, params):
         print("embedding")
         print(params)
 
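Note on the `faiss_cache.py` change: dropping `normalize_L2=True` while keeping `distance_strategy="METRIC_INNER_PRODUCT"` means FAISS now ranks by raw dot products. That is equivalent to cosine similarity only when the embedding model already returns unit-length vectors, in which case the extra normalization pass is redundant. A minimal sketch of that identity (the two vectors below are illustrative stand-ins for pre-normalized embeddings, not real model output):

```python
import numpy as np

# Stand-ins for embeddings from a model that L2-normalizes its outputs:
# both vectors already have unit length (||a|| == ||b|| == 1).
a = np.array([0.6, 0.8])
b = np.array([0.8, 0.6])

inner = float(a @ b)  # what METRIC_INNER_PRODUCT ranks by
cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
assert abs(inner - cosine) < 1e-12  # identical for unit vectors
```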
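Note on the `zhipu.py` change: `do_embeddings` now mints a fresh token inside the loop (so a long run of texts cannot outlive the 60-second expiry passed to `generate_token`) and delegates each request to `request_embedding_api`, which retries once before giving up; failed batches return None and are skipped, so `result` may end up shorter than `params.texts`. A standalone sketch of the same retry pattern; the name `fetch_embedding` and the `timeout` argument are illustrative, not part of the patch:

```python
import requests

def fetch_embedding(url: str, headers: dict, data: dict, retry: int = 0):
    """Mirror of the patch's retry pattern: attempt the POST, and on any
    exception recurse with retry - 1 until retries run out, then return None."""
    response = None
    try:
        response = requests.post(url, headers=headers, json=data, timeout=30)
        return response.json()["data"][0]["embedding"]
    except Exception as e:
        print(f"fetch_embedding error={e}\nresponse={response}")
        if retry > 0:
            return fetch_embedding(url, headers, data, retry - 1)
        return None
```

A retry like this masks transient network errors but also swallows hard failures (bad API key, malformed request), which is presumably why the patch keeps the count at a single retry.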