mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
commit
cbc28d7296
@ -57,7 +57,7 @@ class _FaissPool(CachePool):
|
||||
) -> FAISS:
|
||||
embeddings = EmbeddingsFunAdapter(embed_model)
|
||||
doc = Document(page_content="init", metadata={})
|
||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
vector_store = FAISS.from_documents([doc], embeddings, distance_strategy="METRIC_INNER_PRODUCT")
|
||||
ids = list(vector_store.docstore._dict.keys())
|
||||
vector_store.delete(ids)
|
||||
return vector_store
|
||||
@ -94,7 +94,7 @@ class KBFaissPool(_FaissPool):
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, distance_strategy="METRIC_INNER_PRODUCT")
|
||||
elif create:
|
||||
# create an empty vector store
|
||||
if not os.path.exists(vs_path):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from fastchat.conversation import Conversation
|
||||
from httpx_sse import EventSource
|
||||
|
||||
@ -44,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["zhipu-api"],
|
||||
model_names: List[str] = ("zhipu-api",),
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: Literal["glm-4"] = "glm-4",
|
||||
@ -87,28 +88,45 @@ class ChatGLMWorker(ApiModelWorker):
|
||||
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
|
||||
|
||||
params.load_config(self.model_names[0])
|
||||
token = generate_token(params.api_key, 60)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
i = 0
|
||||
batch_size = 1
|
||||
result = []
|
||||
while i < len(params.texts):
|
||||
token = generate_token(params.api_key, 60)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
data = {
|
||||
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||
"model": embed_model,
|
||||
"input": "".join(params.texts[i: i + batch_size])
|
||||
}
|
||||
embedding_data = self.request_embedding_api(headers, data, 1)
|
||||
if embedding_data:
|
||||
result.append(embedding_data)
|
||||
i += batch_size
|
||||
print(f"请求{embed_model}接口处理第{i}块文本,返回embeddings: \n{embedding_data}")
|
||||
|
||||
return {"code": 200, "data": result}
|
||||
|
||||
# 请求接口,支持重试
|
||||
def request_embedding_api(self, headers, data, retry=0):
|
||||
response = ''
|
||||
try:
|
||||
url = "https://open.bigmodel.cn/api/paas/v4/embeddings"
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
ans = response.json()
|
||||
result.append(ans["data"][0]["embedding"])
|
||||
i += batch_size
|
||||
return ans["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
print(f"request_embedding_api error={e} \nresponse={response}")
|
||||
if retry > 0:
|
||||
return self.request_embedding_api(headers, data, retry - 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
return {"code": 200, "data": result}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user