embeddings模块集成openai plugins插件,使用统一api调用

This commit is contained in:
glide-the 2024-01-25 02:00:02 +08:00 committed by liunux4odoo
parent 217bb61448
commit 307b973f26
18 changed files with 447 additions and 364 deletions

2
.gitignore vendored
View File

@ -180,3 +180,5 @@ test.py
configs/*.py
/knowledge_base/samples/content/202311-D平台项目工作大纲参数人员中间库表结构说明V1.1(1).docx
/knowledge_base/samples/content/imi_temeplate.txt

View File

@ -16,7 +16,7 @@ from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.embeddings.core.embeddings_api import embed_texts_endpoint
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)

View File

@ -1,8 +1,9 @@
from fastapi import Body, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.embeddings.adapter import load_temp_adapter_embeddings
from server.utils import (wrap_done, get_ChatOpenAI,
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -10,22 +11,23 @@ from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
def _parse_files_in_thread(
files: List[UploadFile],
dir: str,
zh_title_enhance: bool,
chunk_size: int,
chunk_overlap: int,
files: List[UploadFile],
dir: str,
zh_title_enhance: bool,
chunk_size: int,
chunk_overlap: int,
):
"""
通过多线程将上传的文件保存到对应目录内
生成器返回保存结果[success or error, filename, msg, docs]
"""
def parse_file(file: UploadFile) -> dict:
'''
保存单个文件
@ -55,11 +57,14 @@ def _parse_files_in_thread(
def upload_temp_docs(
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
prev_id: str = Form(None, description="前知识库ID"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
prev_id: str = Form(None, description="前知识库ID"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse:
'''
将文件保存到临时目录并进行向量化
@ -72,16 +77,20 @@ def upload_temp_docs(
documents = []
path, id = get_temp_dir(prev_id)
for success, file, msg, docs in _parse_files_in_thread(files=files,
dir=path,
zh_title_enhance=zh_title_enhance,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap):
dir=path,
zh_title_enhance=zh_title_enhance,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap):
if success:
documents += docs
else:
failed_files.append({file: msg})
with memo_faiss_pool.load_vector_store(id).acquire() as vs:
with memo_faiss_pool.load_vector_store(kb_name=id,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
).acquire() as vs:
vs.add_documents(documents)
return BaseResponse(data={"id": id, "failed_files": failed_files})
@ -89,15 +98,17 @@ def upload_temp_docs(
async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_id: str = Body(..., description="临时知识库ID"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=2),
score_threshold: float = Body(SCORE_THRESHOLD,
description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
ge=0, le=2),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
@ -105,8 +116,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if knowledge_id not in memo_faiss_pool.keys():
return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件")
@ -127,7 +139,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
max_tokens=max_tokens,
callbacks=[callback],
)
embed_func = EmbeddingsFunAdapter()
embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy)
embeddings = await embed_func.aembed_query(query)
with memo_faiss_pool.acquire(knowledge_id) as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@ -156,7 +170,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
text = f"""出处 [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
if stream:

View File

@ -12,6 +12,9 @@ class KnowledgeBaseModel(Base):
kb_name = Column(String(50), comment='知识库名称')
kb_info = Column(String(200), comment='知识库简介(用于Agent)')
vs_type = Column(String(50), comment='向量库类型')
endpoint_host = Column(String(50), comment='接入点地址')
endpoint_host_key = Column(String(50), comment='接入点key')
endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
embed_model = Column(String(50), comment='嵌入模型名称')
file_count = Column(Integer, default=0, comment='文件数量')
create_time = Column(DateTime, default=func.now(), comment='创建时间')

View File

@ -48,6 +48,16 @@ def delete_kb_from_db(session, kb_name):
return True
@with_session
def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
if kb:
kb.endpoint_host = endpoint_host
kb.endpoint_host_key = endpoint_host_key
kb.endpoint_host_proxy = endpoint_host_proxy
return True
@with_session
def get_kb_detail(session, kb_name: str) -> dict:
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
@ -56,6 +66,9 @@ def get_kb_detail(session, kb_name: str) -> dict:
"kb_name": kb.kb_name,
"kb_info": kb.kb_info,
"vs_type": kb.vs_type,
"endpoint_host": kb.endpoint_host,
"endpoint_host_key": kb.endpoint_host_key,
"endpoint_host_proxy": kb.endpoint_host_proxy,
"embed_model": kb.embed_model,
"file_count": kb.file_count,
"create_time": kb.create_time,

View File

View File

@ -0,0 +1,130 @@
import numpy as np
from typing import List, Union, Dict
from langchain.embeddings.base import Embeddings
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_MODEL, KB_INFO)
from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
from server.utils import embedding_device
class EmbeddingsFunAdapter(Embeddings):
_endpoint_host: str
_endpoint_host_key: str
_endpoint_host_proxy: str
def __init__(self,
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
):
self._endpoint_host = endpoint_host
self._endpoint_host_key = endpoint_host_key
self._endpoint_host_proxy = endpoint_host_proxy
self.embed_model = embed_model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = embed_texts(texts=texts,
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=False).data
return self._normalize(embeddings=embeddings).tolist()
def embed_query(self, text: str) -> List[float]:
embeddings = embed_texts(texts=[text],
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=True).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = self._normalize(embeddings=query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = (await aembed_texts(texts=texts,
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=False)).data
return self._normalize(embeddings=embeddings).tolist()
async def aembed_query(self, text: str) -> List[float]:
embeddings = (await aembed_texts(texts=[text],
endpoint_host=self._endpoint_host,
endpoint_host_key=self._endpoint_host_key,
endpoint_host_proxy=self._endpoint_host_proxy,
embed_model=self.embed_model,
to_query=True)).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = self._normalize(embeddings=query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
@staticmethod
def _normalize(embeddings: List[List[float]]) -> np.ndarray:
'''
sklearn.preprocessing.normalize 的替代使用 L2避免安装 scipy, scikit-learn
#TODO 此处内容处理错误
'''
norm = np.linalg.norm(embeddings, axis=1)
norm = np.reshape(norm, (norm.shape[0], 1))
norm = np.tile(norm, (1, len(embeddings[0])))
return np.divide(embeddings, norm)
def load_kb_adapter_embeddings(
kb_name: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
"""
加载知识库配置的Embeddings模型
本地模型最终会通过load_embeddings加载
在线模型会在适配器中直接返回
:param kb_name:
:param embed_device:
:param default_embed_model:
:return:
"""
from server.db.repository.knowledge_base_repository import get_kb_detail
kb_detail = get_kb_detail(kb_name)
embed_model = kb_detail.get("embed_model", default_embed_model)
endpoint_host = kb_detail.get("endpoint_host", None)
endpoint_host_key = kb_detail.get("endpoint_host_key", None)
endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None)
return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model)
def load_temp_adapter_embeddings(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
"""
加载临时的Embeddings模型
本地模型最终会通过load_embeddings加载
在线模型会在适配器中直接返回
:param endpoint_host:
:param endpoint_host_key:
:param endpoint_host_proxy:
:param embed_device:
:param default_embed_model:
:return:
"""
return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=default_embed_model)

View File

View File

@ -0,0 +1,94 @@
from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
def embed_texts(
texts: List[str],
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
TODO: 也许需要加入缓存机制减少 token 消耗
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=embeddings.embed_documents(texts))
# 使用在线API
if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy):
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model=embed_model,
openai_api_key=endpoint_host_key if endpoint_host_key else "None",
openai_api_base=endpoint_host if endpoint_host else "None",
openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
chunk_size=CHUNK_SIZE)
return BaseResponse(data=embeddings.embed_documents(texts))
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
# 使用在线API
if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy):
return await run_in_threadpool(embed_texts,
texts=texts,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
endpoint_host: str = Body(False, description="接入点地址"),
endpoint_host_key: str = Body(False, description="接入点key"),
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"),
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
'''
接入api对文本进行向量化返回 BaseResponse(data=List[List[float]])
'''
return embed_texts(texts=texts,
endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model, to_query=to_query)

View File

@ -1,96 +0,0 @@
from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger
# from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
online_embed_models = list_online_embed_models()
def embed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=embeddings.embed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
config = get_model_worker_config(embed_model)
worker_class = config.get("worker_class")
embed_model = config.get("embed_model")
worker = worker_class()
if worker_class.can_embedding():
# params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
resp = worker.do_embeddings(None)
return BaseResponse(**resp)
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
return await run_in_threadpool(embed_texts,
texts=texts,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL,
description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"),
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
'''
对文本进行向量化返回 BaseResponse(data=List[List[float]])
'''
return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
def embed_documents(
docs: List[Document],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> Dict:
"""
List[Document] 向量化转化为 VectorStore.add_embeddings 可以接受的参数
"""
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
if embeddings is not None:
return {
"texts": texts,
"embeddings": embeddings,
"metadatas": metadatas,
}

View File

@ -3,7 +3,7 @@ from langchain.vectorstores.faiss import FAISS
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
logger, log_verbose)
from server.utils import embedding_device, get_model_path, list_online_embed_models
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
@ -98,26 +98,16 @@ class CachePool:
else:
return cache
def load_kb_embeddings(
self,
kb_name: str,
embed_device: str = embedding_device(),
default_embed_model: str = EMBEDDING_MODEL,
) -> Embeddings:
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
kb_detail = get_kb_detail(kb_name)
embed_model = kb_detail.get("embed_model", default_embed_model)
if embed_model in list_online_embed_models():
return EmbeddingsFunAdapter(embed_model)
else:
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
"""
本地Embeddings模型加载
:param model:
:param device:
:return:
"""
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = embedding_device()
@ -127,12 +117,7 @@ class EmbeddingsPool(CachePool):
self.set(key, item)
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model=model,
openai_api_key=get_model_path(model),
chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
if 'bge-' in model:
from langchain.embeddings import HuggingFaceBgeEmbeddings
if 'zh' in model:
# for chinese model

View File

@ -1,7 +1,7 @@
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
# from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
@ -51,18 +51,42 @@ class ThreadSafeFaiss(ThreadSafeObject):
class _FaissPool(CachePool):
def new_vector_store(
self,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
self,
kb_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
embeddings = EmbeddingsFunAdapter(embed_model)
# create an empty vector store
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
embed_device=embed_device, default_embed_model=embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
def save_vector_store(self, kb_name: str, path: str=None):
def new_temp_vector_store(
self,
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
# create an empty vector store
embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_device=embed_device, default_embed_model=embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
def save_vector_store(self, kb_name: str, path: str = None):
if cache := self.get(kb_name):
return cache.save(path)
@ -83,39 +107,52 @@ class KBFaissPool(_FaissPool):
) -> ThreadSafeFaiss:
self.atomic.acquire()
vector_name = vector_name or embed_model
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None:
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
try:
if cache is None:
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
embed_device=embed_device, default_embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(kb_name=kb_name,
embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
except Exception as e:
self.atomic.release()
logger.error(e)
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
return self.get((kb_name, vector_name))
class MemoFaissPool(_FaissPool):
r"""
临时向量库的缓存池
"""
def load_vector_store(
self,
kb_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
self,
kb_name: str,
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
@ -126,7 +163,10 @@ class MemoFaissPool(_FaissPool):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' to memory.")
# create an empty vector store
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
endpoint_host_key=endpoint_host_key,
endpoint_host_proxy=endpoint_host_proxy,
embed_model=embed_model, embed_device=embed_device)
item.obj = vector_store
item.finish_loading()
else:
@ -136,40 +176,40 @@ class MemoFaissPool(_FaissPool):
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
if __name__ == "__main__":
import time, random
from pprint import pprint
kb_names = ["vs1", "vs2", "vs3"]
# for name in kb_names:
# memo_faiss_pool.load_vector_store(name)
def worker(vs_name: str, name: str):
vs_name = "samples"
time.sleep(random.randint(1, 5))
embeddings = load_local_embeddings()
r = random.randint(1, 3)
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
if r == 1: # add docs
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()
#
#
# if __name__ == "__main__":
# import time, random
# from pprint import pprint
#
# kb_names = ["vs1", "vs2", "vs3"]
# # for name in kb_names:
# # memo_faiss_pool.load_vector_store(name)
#
# def worker(vs_name: str, name: str):
# vs_name = "samples"
# time.sleep(random.randint(1, 5))
# embeddings = load_local_embeddings()
# r = random.randint(1, 3)
#
# with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
# if r == 1: # add docs
# ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
# pprint(ids)
# elif r == 2: # search docs
# docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
# pprint(docs)
# if r == 3: # delete docs
# logger.warning(f"清除 {vs_name} by {name}")
# kb_faiss_pool.get(vs_name).clear()
#
# threads = []
# for n in range(1, 30):
# t = threading.Thread(target=worker,
# kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
# daemon=True)
# t.start()
# threads.append(t)
#
# for t in threads:
# t.join()

View File

@ -136,8 +136,7 @@ def upload_docs(
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
docs: Json = Form({}, description="自定义的docs需要转为json字符串"),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
"""
@ -238,8 +237,7 @@ def update_docs(
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
docs: Json = Body({}, description="自定义的docs需要转为json字符串"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
"""

View File

@ -1,10 +1,7 @@
import operator
from abc import ABC, abstractmethod
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from server.db.repository.knowledge_base_repository import (
@ -26,20 +23,9 @@ from server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def normalize(embeddings: List[List[float]]) -> np.ndarray:
'''
sklearn.preprocessing.normalize 的替代使用 L2避免安装 scipy, scikit-learn
'''
norm = np.linalg.norm(embeddings, axis=1)
norm = np.reshape(norm, (norm.shape[0], 1))
norm = np.tile(norm, (1, len(embeddings[0])))
return np.divide(embeddings, norm)
class SupportedVSType:
FAISS = 'faiss'
MILVUS = 'milvus'
@ -98,12 +84,6 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
'''
List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
'''
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
向知识库添加文件
@ -310,7 +290,7 @@ class KBServiceFactory:
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,embed_model=embed_model)
return MilvusKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.ZILLIZ == vector_store_type:
from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model)
@ -416,33 +396,6 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
return data
class EmbeddingsFunAdapter(Embeddings):
def __init__(self, embed_model: str = EMBEDDING_MODEL):
self.embed_model = embed_model
def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
return normalize(embeddings).tolist()
def embed_query(self, text: str) -> List[float]:
embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
return normalize(embeddings).tolist()
async def aembed_query(self, text: str) -> List[float]:
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
def score_threshold_process(score_threshold, k, docs):
if score_threshold is not None:
cmp = (

View File

@ -2,7 +2,7 @@ import os
import shutil
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
@ -55,16 +55,16 @@ class FaissKBService(KBService):
try:
shutil.rmtree(self.kb_path)
except Exception:
...
pass
def do_search(self,
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]:
embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query)
) -> List[Document]:
with self.load_vector_store().acquire() as vs:
embeddings = vs.embeddings.embed_query(query)
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
return docs
@ -72,12 +72,14 @@ class FaissKBService(KBService):
docs: List[Document],
**kwargs,
) -> List[Dict]:
data = self._docs_to_embeddings(docs) # 将向量化单独出来可以减少向量库的锁定时间
with self.load_vector_store().acquire() as vs:
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
metadatas=data["metadatas"],
ids=kwargs.get("ids"))
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
embeddings = vs.embeddings.embed_documents(texts)
ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
metadatas=metadatas)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]

View File

@ -15,7 +15,7 @@ import langchain.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
from server.utils import run_in_thread_pool, get_model_worker_config
from server.utils import run_in_thread_pool
import json
from typing import List, Union, Dict, Tuple, Generator
import chardet
@ -187,8 +187,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
def make_text_splitter(
splitter_name,
chunk_size,
chunk_overlap,
llm_model,
chunk_overlap
):
"""
根据参数获取特定的分词器
@ -223,10 +222,6 @@ def make_text_splitter(
chunk_overlap=chunk_overlap
)
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
config = get_model_worker_config(llm_model)
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
config.get("model_path")
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast

View File

@ -88,7 +88,6 @@ def get_OpenAI(
verbose: bool = True,
**kwargs: Any,
) -> OpenAI:
# TODO: 从API获取模型信息
model = OpenAI(
streaming=streaming,
@ -319,7 +318,6 @@ def list_embed_models() -> List[str]:
return list(MODEL_PATH["embed_model"])
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
if type in MODEL_PATH:
paths = MODEL_PATH[type]
@ -347,17 +345,6 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
return path_str # THUDM/chatglm06b
# 从server_config中获取服务信息
def get_model_worker_config(model_name: str = None) -> dict:
'''
加载model worker的配置项
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
return {}
def api_address() -> str:
from configs.server_config import API_SERVER
@ -559,15 +546,37 @@ def get_server_configs() -> Dict:
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
def list_online_embed_models() -> List[str]:
def list_online_embed_models(
endpoint_host: str,
endpoint_host_key: str,
endpoint_host_proxy: str
) -> List[str]:
ret = []
# TODO: 从在线API获取支持的模型列表
client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
try:
headers = {
"Authorization": f"Bearer {endpoint_host_key}",
}
resp = client.get("/models", headers=headers)
if resp.status_code == 200:
models = resp.json().get("data", [])
for model in models:
if "embedding" in model.get("id", None):
ret.append(model.get("id"))
except Exception as e:
msg = f"获取在线Embeddings模型列表失败{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
client.close()
return ret
def load_local_embeddings(model: str = None, device: str = embedding_device()):
'''
从缓存中加载embeddings可以避免多线程时竞争加载
从缓存中本地Embeddings模型加载可以避免多线程时竞争加载
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
from configs import EMBEDDING_MODEL

View File

@ -1,59 +0,0 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))
from configs import ONLINE_LLM_MODEL
from server.model_workers.base import *
from server.utils import get_model_worker_config, list_config_llm_models
from pprint import pprint
import pytest
#
#
# workers = []
# for x in list_config_llm_models()["online"]:
# if x in ONLINE_LLM_MODEL and x not in workers:
# workers.append(x)
# print(f"all workers to test: {workers}")
# workers = ["fangzhou-api"]
@pytest.mark.parametrize("worker", workers)
def test_chat(worker):
params = ApiChatParams(
messages = [
{"role": "user", "content": "你是谁"},
],
)
print(f"\nchat with {worker} \n")
if worker_class := get_model_worker_config(worker).get("worker_class"):
for x in worker_class().do_chat(params):
pprint(x)
assert isinstance(x, dict)
assert x["error_code"] == 0
@pytest.mark.parametrize("worker", workers)
def test_embeddings(worker):
params = ApiEmbeddingsParams(
texts = [
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
]
)
if worker_class := get_model_worker_config(worker).get("worker_class"):
if worker_class.can_embedding():
print(f"\embeddings with {worker} \n")
resp = worker_class().do_embeddings(params)
pprint(resp, depth=2)
assert resp["code"] == 200
assert "data" in resp
embeddings = resp["data"]
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))