mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-03-23 02:35:52 +08:00
embeddings模块集成openai plugins插件,使用统一api调用
This commit is contained in:
parent
217bb61448
commit
307b973f26
2
.gitignore
vendored
2
.gitignore
vendored
@ -180,3 +180,5 @@ test.py
|
||||
configs/*.py
|
||||
|
||||
|
||||
/knowledge_base/samples/content/202311-D平台项目工作大纲参数,人员中间库表结构说明V1.1(1).docx
|
||||
/knowledge_base/samples/content/imi_temeplate.txt
|
||||
|
||||
@ -16,7 +16,7 @@ from starlette.responses import RedirectResponse
|
||||
from server.chat.chat import chat
|
||||
from server.chat.completion import completion
|
||||
from server.chat.feedback import chat_feedback
|
||||
from server.embeddings_api import embed_texts_endpoint
|
||||
from server.embeddings.core.embeddings_api import embed_texts_endpoint
|
||||
|
||||
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
|
||||
get_server_configs, get_prompt_template)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from fastapi import Body, File, Form, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
from server.embeddings.adapter import load_temp_adapter_embeddings
|
||||
from server.utils import (wrap_done, get_ChatOpenAI,
|
||||
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
|
||||
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
||||
@ -10,11 +11,11 @@ from typing import AsyncIterable, List, Optional
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def _parse_files_in_thread(
|
||||
files: List[UploadFile],
|
||||
dir: str,
|
||||
@ -26,6 +27,7 @@ def _parse_files_in_thread(
|
||||
通过多线程将上传的文件保存到对应目录内。
|
||||
生成器返回保存结果:[success or error, filename, msg, docs]
|
||||
"""
|
||||
|
||||
def parse_file(file: UploadFile) -> dict:
|
||||
'''
|
||||
保存单个文件。
|
||||
@ -55,6 +57,9 @@ def _parse_files_in_thread(
|
||||
|
||||
|
||||
def upload_temp_docs(
|
||||
endpoint_host: str = Body(False, description="接入点地址"),
|
||||
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
prev_id: str = Form(None, description="前知识库ID"),
|
||||
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||
@ -81,7 +86,11 @@ def upload_temp_docs(
|
||||
else:
|
||||
failed_files.append({file: msg})
|
||||
|
||||
with memo_faiss_pool.load_vector_store(id).acquire() as vs:
|
||||
with memo_faiss_pool.load_vector_store(kb_name=id,
|
||||
endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
).acquire() as vs:
|
||||
vs.add_documents(documents)
|
||||
return BaseResponse(data={"id": id, "failed_files": failed_files})
|
||||
|
||||
@ -89,7 +98,9 @@ def upload_temp_docs(
|
||||
async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
knowledge_id: str = Body(..., description="临时知识库ID"),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD,
|
||||
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
|
||||
ge=0, le=2),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
@ -105,7 +116,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
prompt_name: str = Body("default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
if knowledge_id not in memo_faiss_pool.keys():
|
||||
return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件")
|
||||
@ -127,7 +139,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
embed_func = EmbeddingsFunAdapter()
|
||||
embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy)
|
||||
embeddings = await embed_func.aembed_query(query)
|
||||
with memo_faiss_pool.acquire(knowledge_id) as vs:
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
|
||||
@ -12,6 +12,9 @@ class KnowledgeBaseModel(Base):
|
||||
kb_name = Column(String(50), comment='知识库名称')
|
||||
kb_info = Column(String(200), comment='知识库简介(用于Agent)')
|
||||
vs_type = Column(String(50), comment='向量库类型')
|
||||
endpoint_host = Column(String(50), comment='接入点地址')
|
||||
endpoint_host_key = Column(String(50), comment='接入点key')
|
||||
endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
|
||||
embed_model = Column(String(50), comment='嵌入模型名称')
|
||||
file_count = Column(Integer, default=0, comment='文件数量')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
@ -48,6 +48,16 @@ def delete_kb_from_db(session, kb_name):
|
||||
return True
|
||||
|
||||
|
||||
@with_session
def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
    """Update the OpenAI-plugin endpoint settings stored for a knowledge base.

    Args:
        session: SQLAlchemy session injected by the ``@with_session`` decorator.
        kb_name: name of the knowledge base record to update.
        endpoint_host: endpoint base address.
        endpoint_host_key: endpoint API key.
        endpoint_host_proxy: endpoint proxy address.

    Returns:
        True if a matching knowledge base was found and updated, False
        otherwise.  (The original fell through and implicitly returned
        None for a missing KB; False keeps the same truthiness while
        making the failure explicit.)
    """
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
    if kb is None:
        # No such knowledge base: report failure explicitly.
        return False
    kb.endpoint_host = endpoint_host
    kb.endpoint_host_key = endpoint_host_key
    kb.endpoint_host_proxy = endpoint_host_proxy
    return True
|
||||
|
||||
|
||||
@with_session
|
||||
def get_kb_detail(session, kb_name: str) -> dict:
|
||||
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
|
||||
@ -56,6 +66,9 @@ def get_kb_detail(session, kb_name: str) -> dict:
|
||||
"kb_name": kb.kb_name,
|
||||
"kb_info": kb.kb_info,
|
||||
"vs_type": kb.vs_type,
|
||||
"endpoint_host": kb.endpoint_host,
|
||||
"endpoint_host_key": kb.endpoint_host_key,
|
||||
"endpoint_host_proxy": kb.endpoint_host_proxy,
|
||||
"embed_model": kb.embed_model,
|
||||
"file_count": kb.file_count,
|
||||
"create_time": kb.create_time,
|
||||
|
||||
0
server/embeddings/__init__.py
Normal file
0
server/embeddings/__init__.py
Normal file
130
server/embeddings/adapter.py
Normal file
130
server/embeddings/adapter.py
Normal file
@ -0,0 +1,130 @@
|
||||
import numpy as np
|
||||
from typing import List, Union, Dict
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
EMBEDDING_MODEL, KB_INFO)
|
||||
from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
|
||||
from server.utils import embedding_device
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
    """Adapter exposing the unified embeddings API as a langchain ``Embeddings``.

    All calls are routed through ``embed_texts`` / ``aembed_texts`` so that
    local models and online (OpenAI-plugin) endpoints are handled uniformly.
    Every returned vector is L2-normalized.
    """

    # Endpoint settings forwarded verbatim to embed_texts/aembed_texts.
    _endpoint_host: str
    _endpoint_host_key: str
    _endpoint_host_proxy: str

    def __init__(self,
                 endpoint_host: str,
                 endpoint_host_key: str,
                 endpoint_host_proxy: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self._endpoint_host = endpoint_host
        self._endpoint_host_key = endpoint_host_key
        self._endpoint_host_proxy = endpoint_host_proxy
        self.embed_model = embed_model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* for storage (``to_query=False``) and L2-normalize."""
        embeddings = embed_texts(texts=texts,
                                 endpoint_host=self._endpoint_host,
                                 endpoint_host_key=self._endpoint_host_key,
                                 endpoint_host_proxy=self._endpoint_host_proxy,
                                 embed_model=self.embed_model,
                                 to_query=False).data
        return self._normalize(embeddings=embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (``to_query=True``) and L2-normalize."""
        embeddings = embed_texts(texts=[text],
                                 endpoint_host=self._endpoint_host,
                                 endpoint_host_key=self._endpoint_host_key,
                                 endpoint_host_proxy=self._endpoint_host_proxy,
                                 embed_model=self.embed_model,
                                 to_query=True).data
        query_embed = embeddings[0]
        # Reshape the single vector to 2-D so _normalize can work row-wise.
        query_embed_2d = np.reshape(query_embed, (1, -1))
        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
        return normalized_query_embed[0].tolist()  # back to a flat list

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant of :meth:`embed_documents`."""
        embeddings = (await aembed_texts(texts=texts,
                                         endpoint_host=self._endpoint_host,
                                         endpoint_host_key=self._endpoint_host_key,
                                         endpoint_host_proxy=self._endpoint_host_proxy,
                                         embed_model=self.embed_model,
                                         to_query=False)).data
        return self._normalize(embeddings=embeddings).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant of :meth:`embed_query`."""
        embeddings = (await aembed_texts(texts=[text],
                                         endpoint_host=self._endpoint_host,
                                         endpoint_host_key=self._endpoint_host_key,
                                         endpoint_host_proxy=self._endpoint_host_proxy,
                                         embed_model=self.embed_model,
                                         to_query=True)).data
        query_embed = embeddings[0]
        # Reshape the single vector to 2-D so _normalize can work row-wise.
        query_embed_2d = np.reshape(query_embed, (1, -1))
        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
        return normalized_query_embed[0].tolist()  # back to a flat list

    @staticmethod
    def _normalize(embeddings: List[List[float]]) -> np.ndarray:
        '''
        L2-normalize each row.  Stand-in for sklearn.preprocessing.normalize
        to avoid depending on scipy / scikit-learn.

        Fix for the original "#TODO 此处内容处理错误": an all-zero row used to be
        divided by a zero norm, producing NaNs; such rows are now returned
        unchanged.
        '''
        arr = np.asarray(embeddings, dtype=float)
        norm = np.linalg.norm(arr, axis=1, keepdims=True)
        # Dividing a zero row by 1 leaves it as zeros instead of NaN.
        norm[norm == 0.0] = 1.0
        return arr / norm
|
||||
|
||||
|
||||
def load_kb_adapter_embeddings(
        kb_name: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    Build the embeddings adapter configured for a knowledge base.

    Endpoint settings and the embedding model name are read from the KB's
    database record.  Local models are ultimately loaded via load_embeddings;
    online models are served directly by the adapter.

    :param kb_name: name of the knowledge base.
    :param embed_device: device hint for local embedding models.
    :param default_embed_model: fallback model when the record has none.
    :return: a configured EmbeddingsFunAdapter.
    """
    # Imported lazily to avoid a circular import with the repository module.
    from server.db.repository.knowledge_base_repository import get_kb_detail

    detail = get_kb_detail(kb_name)
    return EmbeddingsFunAdapter(
        endpoint_host=detail.get("endpoint_host", None),
        endpoint_host_key=detail.get("endpoint_host_key", None),
        endpoint_host_proxy=detail.get("endpoint_host_proxy", None),
        embed_model=detail.get("embed_model", default_embed_model),
    )
|
||||
|
||||
|
||||
def load_temp_adapter_embeddings(
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    Build an embeddings adapter for a temporary (non-persisted) knowledge base.

    Local models are ultimately loaded via load_embeddings; online models are
    served directly by the adapter using the endpoint settings given here.

    :param endpoint_host: endpoint base address.
    :param endpoint_host_key: endpoint API key.
    :param endpoint_host_proxy: endpoint proxy address.
    :param embed_device: device hint for local embedding models.
    :param default_embed_model: embedding model to use.
    :return: a configured EmbeddingsFunAdapter.
    """
    return EmbeddingsFunAdapter(
        endpoint_host=endpoint_host,
        endpoint_host_key=endpoint_host_key,
        endpoint_host_proxy=endpoint_host_proxy,
        embed_model=default_embed_model,
    )
|
||||
0
server/embeddings/core/__init__.py
Normal file
0
server/embeddings/core/__init__.py
Normal file
94
server/embeddings/core/embeddings_api.py
Normal file
94
server/embeddings/core/embeddings_api.py
Normal file
@ -0,0 +1,94 @@
|
||||
from langchain.docstore.document import Document
|
||||
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
|
||||
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
|
||||
from fastapi import Body
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def embed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    TODO: 也许需要加入缓存机制,减少 token 消耗
    '''
    try:
        # Local embedding models take priority over the online API.
        if embed_model in list_embed_models():
            from server.utils import load_local_embeddings
            local_model = load_local_embeddings(model=embed_model)
            return BaseResponse(data=local_model.embed_documents(texts))

        online_models = list_online_embed_models(endpoint_host=endpoint_host,
                                                 endpoint_host_key=endpoint_host_key,
                                                 endpoint_host_proxy=endpoint_host_proxy)
        if embed_model in online_models:
            # Online API through an OpenAI-compatible endpoint.
            from langchain.embeddings.openai import OpenAIEmbeddings
            client = OpenAIEmbeddings(
                model=embed_model,
                openai_api_key=endpoint_host_key or "None",
                openai_api_base=endpoint_host or "None",
                openai_proxy=endpoint_host_proxy or None,
                chunk_size=CHUNK_SIZE,
            )
            return BaseResponse(data=client.embed_documents(texts))

        # Model is in neither list: report it as unsupported.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
async def aembed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    '''
    try:
        if embed_model in list_embed_models():  # local embeddings model
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # Online API: delegate to the sync implementation in a worker thread.
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # BUG FIX: an unsupported model previously fell through and returned
        # None; return the same 500 response as the sync embed_texts does.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
def embed_texts_endpoint(
        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
        endpoint_host: str = Body(False, description="接入点地址"),
        endpoint_host_key: str = Body(False, description="接入点key"),
        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
        embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    '''
    接入api,对文本进行向量化,返回 BaseResponse(data=List[List[float]])
    '''
    # Thin FastAPI wrapper: forward everything to the sync implementation.
    return embed_texts(
        texts=texts,
        endpoint_host=endpoint_host,
        endpoint_host_key=endpoint_host_key,
        endpoint_host_proxy=endpoint_host_proxy,
        embed_model=embed_model,
        to_query=to_query,
    )
|
||||
@ -1,96 +0,0 @@
|
||||
from langchain.docstore.document import Document
|
||||
from configs import EMBEDDING_MODEL, logger
|
||||
# from server.model_workers.base import ApiEmbeddingsParams
|
||||
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
|
||||
from fastapi import Body
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Dict, List
|
||||
|
||||
online_embed_models = list_online_embed_models()
|
||||
|
||||
|
||||
def embed_texts(
        texts: List[str],
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    '''
    try:
        if embed_model in list_embed_models():
            # Local embeddings model.
            from server.utils import load_local_embeddings

            local_model = load_local_embeddings(model=embed_model)
            return BaseResponse(data=local_model.embed_documents(texts))

        if embed_model in list_online_embed_models():
            # Online API worker.
            cfg = get_model_worker_config(embed_model)
            worker_cls = cfg.get("worker_class")
            embed_model = cfg.get("embed_model")
            worker = worker_cls()
            if worker_cls.can_embedding():
                # params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
                resp = worker.do_embeddings(None)
                return BaseResponse(**resp)

        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
async def aembed_texts(
        texts: List[str],
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
    '''
    try:
        if embed_model in list_embed_models():  # local embeddings model
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        if embed_model in list_online_embed_models():  # online API
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # BUG FIX: an unsupported model previously fell through and returned
        # None; return the same 500 response as the sync embed_texts does.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
def embed_texts_endpoint(
        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
        embed_model: str = Body(EMBEDDING_MODEL,
                                description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    '''
    对文本进行向量化,返回 BaseResponse(data=List[List[float]])
    '''
    # Thin FastAPI wrapper around the sync implementation.
    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
|
||||
|
||||
|
||||
def embed_documents(
        docs: List[Document],
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> Dict:
    """
    将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
    """
    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]
    vectors = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
    # On embedding failure (.data is None) fall through and return None,
    # matching the original behavior.
    if vectors is not None:
        return {
            "texts": texts,
            "embeddings": vectors,
            "metadatas": metadatas,
        }
|
||||
@ -3,7 +3,7 @@ from langchain.vectorstores.faiss import FAISS
|
||||
import threading
|
||||
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
||||
logger, log_verbose)
|
||||
from server.utils import embedding_device, get_model_path, list_online_embed_models
|
||||
from server.utils import embedding_device, get_model_path
|
||||
from contextlib import contextmanager
|
||||
from collections import OrderedDict
|
||||
from typing import List, Any, Union, Tuple
|
||||
@ -98,26 +98,16 @@ class CachePool:
|
||||
else:
|
||||
return cache
|
||||
|
||||
def load_kb_embeddings(
|
||||
self,
|
||||
kb_name: str,
|
||||
embed_device: str = embedding_device(),
|
||||
default_embed_model: str = EMBEDDING_MODEL,
|
||||
) -> Embeddings:
|
||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
|
||||
kb_detail = get_kb_detail(kb_name)
|
||||
embed_model = kb_detail.get("embed_model", default_embed_model)
|
||||
|
||||
if embed_model in list_online_embed_models():
|
||||
return EmbeddingsFunAdapter(embed_model)
|
||||
else:
|
||||
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
|
||||
|
||||
|
||||
class EmbeddingsPool(CachePool):
|
||||
|
||||
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
|
||||
"""
|
||||
本地Embeddings模型加载
|
||||
:param model:
|
||||
:param device:
|
||||
:return:
|
||||
"""
|
||||
self.atomic.acquire()
|
||||
model = model or EMBEDDING_MODEL
|
||||
device = embedding_device()
|
||||
@ -127,12 +117,7 @@ class EmbeddingsPool(CachePool):
|
||||
self.set(key, item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings(model=model,
|
||||
openai_api_key=get_model_path(model),
|
||||
chunk_size=CHUNK_SIZE)
|
||||
elif 'bge-' in model:
|
||||
if 'bge-' in model:
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
if 'zh' in model:
|
||||
# for chinese model
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
|
||||
from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
|
||||
from server.knowledge_base.kb_cache.base import *
|
||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
from server.utils import load_local_embeddings
|
||||
# from server.utils import load_local_embeddings
|
||||
from server.knowledge_base.utils import get_vs_path
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
@ -52,16 +52,40 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
||||
class _FaissPool(CachePool):
|
||||
def new_vector_store(
|
||||
self,
|
||||
kb_name: str,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
) -> FAISS:
|
||||
embeddings = EmbeddingsFunAdapter(embed_model)
|
||||
|
||||
# create an empty vector store
|
||||
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
|
||||
embed_device=embed_device, default_embed_model=embed_model)
|
||||
doc = Document(page_content="init", metadata={})
|
||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
ids = list(vector_store.docstore._dict.keys())
|
||||
vector_store.delete(ids)
|
||||
return vector_store
|
||||
|
||||
def new_temp_vector_store(
|
||||
self,
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
) -> FAISS:
|
||||
|
||||
# create an empty vector store
|
||||
embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
embed_device=embed_device, default_embed_model=embed_model)
|
||||
doc = Document(page_content="init", metadata={})
|
||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
|
||||
ids = list(vector_store.docstore._dict.keys())
|
||||
vector_store.delete(ids)
|
||||
return vector_store
|
||||
|
||||
def save_vector_store(self, kb_name: str, path: str = None):
|
||||
if cache := self.get(kb_name):
|
||||
return cache.save(path)
|
||||
@ -84,6 +108,7 @@ class KBFaissPool(_FaissPool):
|
||||
self.atomic.acquire()
|
||||
vector_name = vector_name or embed_model
|
||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||
try:
|
||||
if cache is None:
|
||||
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
|
||||
self.set((kb_name, vector_name), item)
|
||||
@ -93,13 +118,15 @@ class KBFaissPool(_FaissPool):
|
||||
vs_path = get_vs_path(kb_name, vector_name)
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
|
||||
embed_device=embed_device, default_embed_model=embed_model)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
elif create:
|
||||
# create an empty vector store
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
|
||||
vector_store = self.new_vector_store(kb_name=kb_name,
|
||||
embed_model=embed_model, embed_device=embed_device)
|
||||
vector_store.save_local(vs_path)
|
||||
else:
|
||||
raise RuntimeError(f"knowledge base {kb_name} not exist.")
|
||||
@ -107,13 +134,23 @@ class KBFaissPool(_FaissPool):
|
||||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
except Exception as e:
|
||||
self.atomic.release()
|
||||
logger.error(e)
|
||||
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||
return self.get((kb_name, vector_name))
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
|
||||
r"""
|
||||
临时向量库的缓存池
|
||||
"""
|
||||
def load_vector_store(
|
||||
self,
|
||||
kb_name: str,
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = embedding_device(),
|
||||
) -> ThreadSafeFaiss:
|
||||
@ -126,7 +163,10 @@ class MemoFaissPool(_FaissPool):
|
||||
self.atomic.release()
|
||||
logger.info(f"loading vector store in '{kb_name}' to memory.")
|
||||
# create an empty vector store
|
||||
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
|
||||
vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
|
||||
endpoint_host_key=endpoint_host_key,
|
||||
endpoint_host_proxy=endpoint_host_proxy,
|
||||
embed_model=embed_model, embed_device=embed_device)
|
||||
item.obj = vector_store
|
||||
item.finish_loading()
|
||||
else:
|
||||
@ -136,40 +176,40 @@ class MemoFaissPool(_FaissPool):
|
||||
|
||||
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
||||
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time, random
|
||||
from pprint import pprint
|
||||
|
||||
kb_names = ["vs1", "vs2", "vs3"]
|
||||
# for name in kb_names:
|
||||
# memo_faiss_pool.load_vector_store(name)
|
||||
|
||||
def worker(vs_name: str, name: str):
|
||||
vs_name = "samples"
|
||||
time.sleep(random.randint(1, 5))
|
||||
embeddings = load_local_embeddings()
|
||||
r = random.randint(1, 3)
|
||||
|
||||
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
|
||||
if r == 1: # add docs
|
||||
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
|
||||
pprint(ids)
|
||||
elif r == 2: # search docs
|
||||
docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
|
||||
pprint(docs)
|
||||
if r == 3: # delete docs
|
||||
logger.warning(f"清除 {vs_name} by {name}")
|
||||
kb_faiss_pool.get(vs_name).clear()
|
||||
|
||||
threads = []
|
||||
for n in range(1, 30):
|
||||
t = threading.Thread(target=worker,
|
||||
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
|
||||
daemon=True)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import time, random
|
||||
# from pprint import pprint
|
||||
#
|
||||
# kb_names = ["vs1", "vs2", "vs3"]
|
||||
# # for name in kb_names:
|
||||
# # memo_faiss_pool.load_vector_store(name)
|
||||
#
|
||||
# def worker(vs_name: str, name: str):
|
||||
# vs_name = "samples"
|
||||
# time.sleep(random.randint(1, 5))
|
||||
# embeddings = load_local_embeddings()
|
||||
# r = random.randint(1, 3)
|
||||
#
|
||||
# with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
|
||||
# if r == 1: # add docs
|
||||
# ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
|
||||
# pprint(ids)
|
||||
# elif r == 2: # search docs
|
||||
# docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
|
||||
# pprint(docs)
|
||||
# if r == 3: # delete docs
|
||||
# logger.warning(f"清除 {vs_name} by {name}")
|
||||
# kb_faiss_pool.get(vs_name).clear()
|
||||
#
|
||||
# threads = []
|
||||
# for n in range(1, 30):
|
||||
# t = threading.Thread(target=worker,
|
||||
# kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
|
||||
# daemon=True)
|
||||
# t.start()
|
||||
# threads.append(t)
|
||||
#
|
||||
# for t in threads:
|
||||
# t.join()
|
||||
|
||||
@ -136,8 +136,7 @@ def upload_docs(
|
||||
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||
docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
|
||||
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
docs: Json = Form({}, description="自定义的docs,需要转为json字符串"),
|
||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
"""
|
||||
@ -238,8 +237,7 @@ def update_docs(
|
||||
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||
docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
|
||||
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
||||
docs: Json = Body({}, description="自定义的docs,需要转为json字符串"),
|
||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||
) -> BaseResponse:
|
||||
"""
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
@ -26,20 +23,9 @@ from server.knowledge_base.utils import (
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
|
||||
|
||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
|
||||
'''
|
||||
sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
|
||||
'''
|
||||
norm = np.linalg.norm(embeddings, axis=1)
|
||||
norm = np.reshape(norm, (norm.shape[0], 1))
|
||||
norm = np.tile(norm, (1, len(embeddings[0])))
|
||||
return np.divide(embeddings, norm)
|
||||
|
||||
|
||||
class SupportedVSType:
|
||||
FAISS = 'faiss'
|
||||
MILVUS = 'milvus'
|
||||
@ -98,12 +84,6 @@ class KBService(ABC):
|
||||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
|
||||
'''
|
||||
将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
|
||||
'''
|
||||
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
@ -416,33 +396,6 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
|
||||
def __init__(self, embed_model: str = EMBEDDING_MODEL):
|
||||
self.embed_model = embed_model
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
|
||||
return normalize(embeddings).tolist()
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
|
||||
query_embed = embeddings[0]
|
||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||
normalized_query_embed = normalize(query_embed_2d)
|
||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
|
||||
return normalize(embeddings).tolist()
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
|
||||
query_embed = embeddings[0]
|
||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||
normalized_query_embed = normalize(query_embed_2d)
|
||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||
|
||||
|
||||
def score_threshold_process(score_threshold, k, docs):
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
import shutil
|
||||
|
||||
from configs import SCORE_THRESHOLD
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from server.utils import torch_gc
|
||||
@ -55,16 +55,16 @@ class FaissKBService(KBService):
|
||||
try:
|
||||
shutil.rmtree(self.kb_path)
|
||||
except Exception:
|
||||
...
|
||||
pass
|
||||
|
||||
def do_search(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embed_func = EmbeddingsFunAdapter(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
) -> List[Document]:
|
||||
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
embeddings = vs.embeddings.embed_query(query)
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
return docs
|
||||
|
||||
@ -72,12 +72,14 @@ class FaissKBService(KBService):
|
||||
docs: List[Document],
|
||||
**kwargs,
|
||||
) -> List[Dict]:
|
||||
data = self._docs_to_embeddings(docs) # 将向量化单独出来可以减少向量库的锁定时间
|
||||
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
|
||||
metadatas=data["metadatas"],
|
||||
ids=kwargs.get("ids"))
|
||||
texts = [x.page_content for x in docs]
|
||||
metadatas = [x.metadata for x in docs]
|
||||
embeddings = vs.embeddings.embed_documents(texts)
|
||||
|
||||
ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
|
||||
metadatas=metadatas)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vs.save_local(self.vs_path)
|
||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||
|
||||
@ -15,7 +15,7 @@ import langchain.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from pathlib import Path
|
||||
from server.utils import run_in_thread_pool, get_model_worker_config
|
||||
from server.utils import run_in_thread_pool
|
||||
import json
|
||||
from typing import List, Union, Dict, Tuple, Generator
|
||||
import chardet
|
||||
@ -187,8 +187,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||
def make_text_splitter(
|
||||
splitter_name,
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
llm_model,
|
||||
chunk_overlap
|
||||
):
|
||||
"""
|
||||
根据参数获取特定的分词器
|
||||
@ -223,10 +222,6 @@ def make_text_splitter(
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
||||
config = get_model_worker_config(llm_model)
|
||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
||||
config.get("model_path")
|
||||
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
@ -88,7 +88,6 @@ def get_OpenAI(
|
||||
verbose: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> OpenAI:
|
||||
|
||||
# TODO: 从API获取模型信息
|
||||
model = OpenAI(
|
||||
streaming=streaming,
|
||||
@ -319,7 +318,6 @@ def list_embed_models() -> List[str]:
|
||||
return list(MODEL_PATH["embed_model"])
|
||||
|
||||
|
||||
|
||||
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
if type in MODEL_PATH:
|
||||
paths = MODEL_PATH[type]
|
||||
@ -347,17 +345,6 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
return path_str # THUDM/chatglm06b
|
||||
|
||||
|
||||
# 从server_config中获取服务信息
|
||||
|
||||
def get_model_worker_config(model_name: str = None) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
from configs.server_config import API_SERVER
|
||||
|
||||
@ -559,15 +546,37 @@ def get_server_configs() -> Dict:
|
||||
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||
|
||||
|
||||
def list_online_embed_models() -> List[str]:
|
||||
def list_online_embed_models(
|
||||
endpoint_host: str,
|
||||
endpoint_host_key: str,
|
||||
endpoint_host_proxy: str
|
||||
) -> List[str]:
|
||||
ret = []
|
||||
# TODO: 从在线API获取支持的模型列表
|
||||
client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {endpoint_host_key}",
|
||||
}
|
||||
resp = client.get("/models", headers=headers)
|
||||
if resp.status_code == 200:
|
||||
models = resp.json().get("data", [])
|
||||
for model in models:
|
||||
if "embedding" in model.get("id", None):
|
||||
ret.append(model.get("id"))
|
||||
|
||||
except Exception as e:
|
||||
msg = f"获取在线Embeddings模型列表失败:{e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
finally:
|
||||
client.close()
|
||||
return ret
|
||||
|
||||
|
||||
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
||||
'''
|
||||
从缓存中加载embeddings,可以避免多线程时竞争加载。
|
||||
从缓存中本地Embeddings模型加载,可以避免多线程时竞争加载。
|
||||
'''
|
||||
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||
from configs import EMBEDDING_MODEL
|
||||
|
||||
@ -1,59 +0,0 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root_path = Path(__file__).parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from configs import ONLINE_LLM_MODEL
|
||||
from server.model_workers.base import *
|
||||
from server.utils import get_model_worker_config, list_config_llm_models
|
||||
from pprint import pprint
|
||||
import pytest
|
||||
#
|
||||
#
|
||||
# workers = []
|
||||
# for x in list_config_llm_models()["online"]:
|
||||
# if x in ONLINE_LLM_MODEL and x not in workers:
|
||||
# workers.append(x)
|
||||
# print(f"all workers to test: {workers}")
|
||||
|
||||
# workers = ["fangzhou-api"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
def test_chat(worker):
|
||||
params = ApiChatParams(
|
||||
messages = [
|
||||
{"role": "user", "content": "你是谁"},
|
||||
],
|
||||
)
|
||||
print(f"\nchat with {worker} \n")
|
||||
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
for x in worker_class().do_chat(params):
|
||||
pprint(x)
|
||||
assert isinstance(x, dict)
|
||||
assert x["error_code"] == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("worker", workers)
|
||||
def test_embeddings(worker):
|
||||
params = ApiEmbeddingsParams(
|
||||
texts = [
|
||||
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
|
||||
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
|
||||
]
|
||||
)
|
||||
|
||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
||||
if worker_class.can_embedding():
|
||||
print(f"\embeddings with {worker} \n")
|
||||
resp = worker_class().do_embeddings(params)
|
||||
|
||||
pprint(resp, depth=2)
|
||||
assert resp["code"] == 200
|
||||
assert "data" in resp
|
||||
embeddings = resp["data"]
|
||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
||||
assert isinstance(embeddings[0][0], float)
|
||||
print("向量长度:", len(embeddings[0]))
|
||||
Loading…
x
Reference in New Issue
Block a user