mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-03-23 02:35:52 +08:00
embeddings模块集成openai plugins插件,使用统一api调用
This commit is contained in:
parent
217bb61448
commit
307b973f26
2
.gitignore
vendored
2
.gitignore
vendored
@ -180,3 +180,5 @@ test.py
|
|||||||
configs/*.py
|
configs/*.py
|
||||||
|
|
||||||
|
|
||||||
|
/knowledge_base/samples/content/202311-D平台项目工作大纲参数,人员中间库表结构说明V1.1(1).docx
|
||||||
|
/knowledge_base/samples/content/imi_temeplate.txt
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from starlette.responses import RedirectResponse
|
|||||||
from server.chat.chat import chat
|
from server.chat.chat import chat
|
||||||
from server.chat.completion import completion
|
from server.chat.completion import completion
|
||||||
from server.chat.feedback import chat_feedback
|
from server.chat.feedback import chat_feedback
|
||||||
from server.embeddings_api import embed_texts_endpoint
|
from server.embeddings.core.embeddings_api import embed_texts_endpoint
|
||||||
|
|
||||||
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
|
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
|
||||||
get_server_configs, get_prompt_template)
|
get_server_configs, get_prompt_template)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from fastapi import Body, File, Form, UploadFile
|
from fastapi import Body, File, Form, UploadFile
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||||
|
from server.embeddings.adapter import load_temp_adapter_embeddings
|
||||||
from server.utils import (wrap_done, get_ChatOpenAI,
|
from server.utils import (wrap_done, get_ChatOpenAI,
|
||||||
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
|
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
|
||||||
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
||||||
@ -10,11 +11,11 @@ from typing import AsyncIterable, List, Optional
|
|||||||
import asyncio
|
import asyncio
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
from server.chat.utils import History
|
from server.chat.utils import History
|
||||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
|
||||||
from server.knowledge_base.utils import KnowledgeFile
|
from server.knowledge_base.utils import KnowledgeFile
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def _parse_files_in_thread(
|
def _parse_files_in_thread(
|
||||||
files: List[UploadFile],
|
files: List[UploadFile],
|
||||||
dir: str,
|
dir: str,
|
||||||
@ -26,6 +27,7 @@ def _parse_files_in_thread(
|
|||||||
通过多线程将上传的文件保存到对应目录内。
|
通过多线程将上传的文件保存到对应目录内。
|
||||||
生成器返回保存结果:[success or error, filename, msg, docs]
|
生成器返回保存结果:[success or error, filename, msg, docs]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parse_file(file: UploadFile) -> dict:
|
def parse_file(file: UploadFile) -> dict:
|
||||||
'''
|
'''
|
||||||
保存单个文件。
|
保存单个文件。
|
||||||
@ -55,6 +57,9 @@ def _parse_files_in_thread(
|
|||||||
|
|
||||||
|
|
||||||
def upload_temp_docs(
|
def upload_temp_docs(
|
||||||
|
endpoint_host: str = Body(False, description="接入点地址"),
|
||||||
|
endpoint_host_key: str = Body(False, description="接入点key"),
|
||||||
|
endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
|
||||||
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||||
prev_id: str = Form(None, description="前知识库ID"),
|
prev_id: str = Form(None, description="前知识库ID"),
|
||||||
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||||
@ -81,7 +86,11 @@ def upload_temp_docs(
|
|||||||
else:
|
else:
|
||||||
failed_files.append({file: msg})
|
failed_files.append({file: msg})
|
||||||
|
|
||||||
with memo_faiss_pool.load_vector_store(id).acquire() as vs:
|
with memo_faiss_pool.load_vector_store(kb_name=id,
|
||||||
|
endpoint_host=endpoint_host,
|
||||||
|
endpoint_host_key=endpoint_host_key,
|
||||||
|
endpoint_host_proxy=endpoint_host_proxy,
|
||||||
|
).acquire() as vs:
|
||||||
vs.add_documents(documents)
|
vs.add_documents(documents)
|
||||||
return BaseResponse(data={"id": id, "failed_files": failed_files})
|
return BaseResponse(data={"id": id, "failed_files": failed_files})
|
||||||
|
|
||||||
@ -89,7 +98,9 @@ def upload_temp_docs(
|
|||||||
async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||||
knowledge_id: str = Body(..., description="临时知识库ID"),
|
knowledge_id: str = Body(..., description="临时知识库ID"),
|
||||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
|
score_threshold: float = Body(SCORE_THRESHOLD,
|
||||||
|
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
|
||||||
|
ge=0, le=2),
|
||||||
history: List[History] = Body([],
|
history: List[History] = Body([],
|
||||||
description="历史对话",
|
description="历史对话",
|
||||||
examples=[[
|
examples=[[
|
||||||
@ -105,7 +116,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
|||||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
prompt_name: str = Body("default",
|
||||||
|
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||||
):
|
):
|
||||||
if knowledge_id not in memo_faiss_pool.keys():
|
if knowledge_id not in memo_faiss_pool.keys():
|
||||||
return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件")
|
return BaseResponse(code=404, msg=f"未找到临时知识库 {knowledge_id},请先上传文件")
|
||||||
@ -127,7 +139,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
callbacks=[callback],
|
callbacks=[callback],
|
||||||
)
|
)
|
||||||
embed_func = EmbeddingsFunAdapter()
|
embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
|
||||||
|
endpoint_host_key=endpoint_host_key,
|
||||||
|
endpoint_host_proxy=endpoint_host_proxy)
|
||||||
embeddings = await embed_func.aembed_query(query)
|
embeddings = await embed_func.aembed_query(query)
|
||||||
with memo_faiss_pool.acquire(knowledge_id) as vs:
|
with memo_faiss_pool.acquire(knowledge_id) as vs:
|
||||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||||
|
|||||||
@ -12,6 +12,9 @@ class KnowledgeBaseModel(Base):
|
|||||||
kb_name = Column(String(50), comment='知识库名称')
|
kb_name = Column(String(50), comment='知识库名称')
|
||||||
kb_info = Column(String(200), comment='知识库简介(用于Agent)')
|
kb_info = Column(String(200), comment='知识库简介(用于Agent)')
|
||||||
vs_type = Column(String(50), comment='向量库类型')
|
vs_type = Column(String(50), comment='向量库类型')
|
||||||
|
endpoint_host = Column(String(50), comment='接入点地址')
|
||||||
|
endpoint_host_key = Column(String(50), comment='接入点key')
|
||||||
|
endpoint_host_proxy = Column(String(50), comment='接入点代理地址')
|
||||||
embed_model = Column(String(50), comment='嵌入模型名称')
|
embed_model = Column(String(50), comment='嵌入模型名称')
|
||||||
file_count = Column(Integer, default=0, comment='文件数量')
|
file_count = Column(Integer, default=0, comment='文件数量')
|
||||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||||
|
|||||||
@ -48,6 +48,16 @@ def delete_kb_from_db(session, kb_name):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@with_session
def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
    """Store the embedding endpoint settings on the knowledge base named ``kb_name``.

    Looks the row up by exact ``kb_name`` and, when one is found, overwrites
    its ``endpoint_host``, ``endpoint_host_key`` and ``endpoint_host_proxy``
    columns with the given values.

    :return: ``True``.
        NOTE(review): the flattened diff does not show whether ``return True``
        sat inside the ``if``; it is kept unconditional here -- confirm
        against the repository.
    """
    record = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
    if not record:
        return True

    record.endpoint_host = endpoint_host
    record.endpoint_host_key = endpoint_host_key
    record.endpoint_host_proxy = endpoint_host_proxy
    return True
|
||||||
|
|
||||||
|
|
||||||
@with_session
|
@with_session
|
||||||
def get_kb_detail(session, kb_name: str) -> dict:
|
def get_kb_detail(session, kb_name: str) -> dict:
|
||||||
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
|
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
|
||||||
@ -56,6 +66,9 @@ def get_kb_detail(session, kb_name: str) -> dict:
|
|||||||
"kb_name": kb.kb_name,
|
"kb_name": kb.kb_name,
|
||||||
"kb_info": kb.kb_info,
|
"kb_info": kb.kb_info,
|
||||||
"vs_type": kb.vs_type,
|
"vs_type": kb.vs_type,
|
||||||
|
"endpoint_host": kb.endpoint_host,
|
||||||
|
"endpoint_host_key": kb.endpoint_host_key,
|
||||||
|
"endpoint_host_proxy": kb.endpoint_host_proxy,
|
||||||
"embed_model": kb.embed_model,
|
"embed_model": kb.embed_model,
|
||||||
"file_count": kb.file_count,
|
"file_count": kb.file_count,
|
||||||
"create_time": kb.create_time,
|
"create_time": kb.create_time,
|
||||||
|
|||||||
0
server/embeddings/__init__.py
Normal file
0
server/embeddings/__init__.py
Normal file
130
server/embeddings/adapter.py
Normal file
130
server/embeddings/adapter.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import numpy as np
|
||||||
|
from typing import List, Union, Dict
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
|
EMBEDDING_MODEL, KB_INFO)
|
||||||
|
from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
|
||||||
|
from server.utils import embedding_device
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsFunAdapter(Embeddings):
    """LangChain ``Embeddings`` adapter backed by ``embed_texts`` / ``aembed_texts``.

    Every call forwards the stored endpoint settings and model name to the
    embedding API helpers and returns L2-normalised vectors as plain lists.
    """

    # Endpoint settings forwarded unchanged to embed_texts / aembed_texts.
    _endpoint_host: str
    _endpoint_host_key: str
    _endpoint_host_proxy: str

    def __init__(self,
                 endpoint_host: str,
                 endpoint_host_key: str,
                 endpoint_host_proxy: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Remember the endpoint settings and the embedding model name."""
        self._endpoint_host = endpoint_host
        self._endpoint_host_key = endpoint_host_key
        self._endpoint_host_proxy = endpoint_host_proxy
        self.embed_model = embed_model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` synchronously and return one normalised vector per text.

        :raises RuntimeError: when the embedding call returned no data.
        """
        response = embed_texts(texts=texts,
                               endpoint_host=self._endpoint_host,
                               endpoint_host_key=self._endpoint_host_key,
                               endpoint_host_proxy=self._endpoint_host_proxy,
                               embed_model=self.embed_model,
                               to_query=False)
        return self._normalize(embeddings=self._require_data(response)).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string synchronously and return its normalised vector.

        :raises RuntimeError: when the embedding call returned no data.
        """
        response = embed_texts(texts=[text],
                               endpoint_host=self._endpoint_host,
                               endpoint_host_key=self._endpoint_host_key,
                               endpoint_host_proxy=self._endpoint_host_proxy,
                               embed_model=self.embed_model,
                               to_query=True)
        return self._normalize_query(self._require_data(response))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of :meth:`embed_documents`.

        :raises RuntimeError: when the embedding call returned no data.
        """
        response = await aembed_texts(texts=texts,
                                      endpoint_host=self._endpoint_host,
                                      endpoint_host_key=self._endpoint_host_key,
                                      endpoint_host_proxy=self._endpoint_host_proxy,
                                      embed_model=self.embed_model,
                                      to_query=False)
        return self._normalize(embeddings=self._require_data(response)).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of :meth:`embed_query`.

        :raises RuntimeError: when the embedding call returned no data.
        """
        response = await aembed_texts(texts=[text],
                                      endpoint_host=self._endpoint_host,
                                      endpoint_host_key=self._endpoint_host_key,
                                      endpoint_host_proxy=self._endpoint_host_proxy,
                                      embed_model=self.embed_model,
                                      to_query=True)
        return self._normalize_query(self._require_data(response))

    @staticmethod
    def _require_data(response):
        """Return ``response.data`` or raise a clear error when the call failed.

        The embedding helpers report failures through a response whose
        ``data`` is ``None`` (or by returning nothing at all); surfacing the
        response message here avoids an opaque numpy/TypeError further down.
        """
        data = getattr(response, "data", None)
        if data is None:
            reason = getattr(response, "msg", None) or "no response returned"
            raise RuntimeError(f"embedding request failed: {reason}")
        return data

    @classmethod
    def _normalize_query(cls, embeddings) -> List[float]:
        """Normalise the first vector of ``embeddings`` and return it as a flat list."""
        # _normalize works row-wise on a 2-D array, so lift the single vector to shape (1, dim).
        query_embed_2d = np.reshape(embeddings[0], (1, -1))
        return cls._normalize(embeddings=query_embed_2d)[0].tolist()

    @staticmethod
    def _normalize(embeddings: List[List[float]]) -> np.ndarray:
        '''
        Row-wise L2 normalisation; a stand-in for sklearn.preprocessing.normalize
        that avoids installing scipy / scikit-learn.

        Rows whose norm is zero are returned unchanged (all zeros) instead of
        being divided by zero, and an empty input is returned as an empty array.

        TODO: the original author flagged the handling here as incorrect
        without saying why -- revisit once the intended behaviour is known.
        '''
        vectors = np.asarray(embeddings, dtype=float)
        if vectors.size == 0:
            return vectors
        # keepdims keeps norm as a column so broadcasting divides each row by its own norm.
        norm = np.linalg.norm(vectors, axis=1, keepdims=True)
        norm[norm == 0] = 1.0
        return vectors / norm
|
||||||
|
|
||||||
|
|
||||||
|
def load_kb_adapter_embeddings(
        kb_name: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    Build the embeddings adapter configured for a knowledge base.

    The knowledge base record is read with ``get_kb_detail`` and its
    ``embed_model`` and endpoint settings are handed to ``EmbeddingsFunAdapter``.
    Per the original note, local models are ultimately loaded through
    ``load_embeddings`` while online models are served by the adapter directly.

    :param kb_name: name of the knowledge base whose settings are looked up.
    :param embed_device: accepted for interface compatibility; not used here.
    :param default_embed_model: model used when the record has no
        ``embed_model`` value (key missing or stored as NULL/empty).
    :return: an ``EmbeddingsFunAdapter`` bound to the stored settings.
    """
    # Imported lazily inside the function, as in the original (presumably to avoid an import cycle).
    from server.db.repository.knowledge_base_repository import get_kb_detail

    kb_detail = get_kb_detail(kb_name) or {}
    # `.get(key, default)` would return None for a NULL column, so fall back with `or`.
    embed_model = kb_detail.get("embed_model") or default_embed_model

    return EmbeddingsFunAdapter(endpoint_host=kb_detail.get("endpoint_host"),
                                endpoint_host_key=kb_detail.get("endpoint_host_key"),
                                endpoint_host_proxy=kb_detail.get("endpoint_host_proxy"),
                                embed_model=embed_model)
|
||||||
|
|
||||||
|
|
||||||
|
def load_temp_adapter_embeddings(
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_device: str = embedding_device(),
        default_embed_model: str = EMBEDDING_MODEL,
) -> "EmbeddingsFunAdapter":
    """
    Build an embeddings adapter for a temporary (in-memory) knowledge base.

    Unlike the knowledge-base variant, nothing is read from the database:
    the endpoint settings come straight from the caller and the model is
    always ``default_embed_model``.

    :param endpoint_host: endpoint address forwarded to the adapter.
    :param endpoint_host_key: endpoint key forwarded to the adapter.
    :param endpoint_host_proxy: endpoint proxy address forwarded to the adapter.
    :param embed_device: accepted for interface compatibility; not used here.
    :param default_embed_model: embedding model name given to the adapter.
    :return: an ``EmbeddingsFunAdapter`` bound to the given settings.
    """
    adapter_settings = {
        "endpoint_host": endpoint_host,
        "endpoint_host_key": endpoint_host_key,
        "endpoint_host_proxy": endpoint_host_proxy,
        "embed_model": default_embed_model,
    }
    return EmbeddingsFunAdapter(**adapter_settings)
|
||||||
0
server/embeddings/core/__init__.py
Normal file
0
server/embeddings/core/__init__.py
Normal file
94
server/embeddings/core/embeddings_api.py
Normal file
94
server/embeddings/core/embeddings_api.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from langchain.docstore.document import Document
|
||||||
|
from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
|
||||||
|
from server.utils import BaseResponse, list_embed_models, list_online_embed_models
|
||||||
|
from fastapi import Body
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
def embed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    Embed ``texts`` and return ``BaseResponse(data=List[List[float]])``.

    A model listed by ``list_embed_models()`` is served by the local
    embeddings loader; a model listed by ``list_online_embed_models(...)``
    is served through ``OpenAIEmbeddings`` against the given endpoint.
    Any other model, and any exception, yields a ``code=500`` response
    rather than raising.

    :param texts: texts to embed.
    :param endpoint_host: base URL of the online endpoint; when empty the
        client's own default is used.
    :param endpoint_host_key: API key for the endpoint; the placeholder
        "None" is sent when empty.
    :param endpoint_host_proxy: optional proxy for the endpoint.
    :param embed_model: name of the embedding model.
    :param to_query: accepted for interface compatibility; not used here.

    TODO: a cache may be worth adding to reduce token consumption.
    '''
    try:
        if embed_model in list_embed_models():  # local embeddings model
            from server.utils import load_local_embeddings
            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=embeddings.embed_documents(texts))

        # online API
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            from langchain.embeddings.openai import OpenAIEmbeddings
            # NOTE(review): CHUNK_SIZE is reused as the client's chunk_size here --
            # confirm that the knowledge-base setting is the intended value.
            embeddings = OpenAIEmbeddings(model=embed_model,
                                          openai_api_key=endpoint_host_key if endpoint_host_key else "None",
                                          # A missing host must stay None: the literal string "None"
                                          # would be used verbatim as the base URL.
                                          openai_api_base=endpoint_host if endpoint_host else None,
                                          openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
                                          chunk_size=CHUNK_SIZE)
            return BaseResponse(data=embeddings.embed_documents(texts))

        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def aembed_texts(
        texts: List[str],
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
) -> BaseResponse:
    '''
    Async variant of ``embed_texts``; returns ``BaseResponse(data=List[List[float]])``.

    A local model is awaited through its ``aembed_documents``; an online
    model is handled by running the synchronous ``embed_texts`` in the
    thread pool. Any other model, and any exception, yields a ``code=500``
    response, so a ``BaseResponse`` is always returned.

    :param texts: texts to embed.
    :param endpoint_host: endpoint address forwarded to the online path.
    :param endpoint_host_key: endpoint key forwarded to the online path.
    :param endpoint_host_proxy: endpoint proxy forwarded to the online path.
    :param embed_model: name of the embedding model.
    :param to_query: forwarded to ``embed_texts`` on the online path.
    '''
    try:
        if embed_model in list_embed_models():  # local embeddings model
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # online API
        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
                                                   endpoint_host_key=endpoint_host_key,
                                                   endpoint_host_proxy=endpoint_host_proxy):
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           endpoint_host=endpoint_host,
                                           endpoint_host_key=endpoint_host_key,
                                           endpoint_host_proxy=endpoint_host_proxy,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # Mirror embed_texts: without this the coroutine returned None for an
        # unsupported model and callers failed on `.data`.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||||
|
|
||||||
|
|
||||||
|
def embed_texts_endpoint(
        texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
        # NOTE(review): the three endpoint fields are annotated `str` but default to False;
        # kept as-is because downstream code only tests them for truthiness -- confirm.
        endpoint_host: str = Body(False, description="接入点地址"),
        endpoint_host_key: str = Body(False, description="接入点key"),
        endpoint_host_proxy: str = Body(False, description="接入点代理地址"),
        embed_model: str = Body(EMBEDDING_MODEL, description="使用的嵌入模型"),
        to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    '''
    HTTP entry point: embed ``texts`` and return ``BaseResponse(data=List[List[float]])``.

    All request-body fields are passed straight through to ``embed_texts``.
    '''
    return embed_texts(texts=texts,
                       endpoint_host=endpoint_host,
                       endpoint_host_key=endpoint_host_key,
                       endpoint_host_proxy=endpoint_host_proxy,
                       embed_model=embed_model, to_query=to_query)
|
||||||
@ -1,96 +0,0 @@
|
|||||||
from langchain.docstore.document import Document
|
|
||||||
from configs import EMBEDDING_MODEL, logger
|
|
||||||
# from server.model_workers.base import ApiEmbeddingsParams
|
|
||||||
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
|
|
||||||
from fastapi import Body
|
|
||||||
from fastapi.concurrency import run_in_threadpool
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
online_embed_models = list_online_embed_models()
|
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(
|
|
||||||
texts: List[str],
|
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
|
||||||
to_query: bool = False,
|
|
||||||
) -> BaseResponse:
|
|
||||||
'''
|
|
||||||
对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
if embed_model in list_embed_models(): # 使用本地Embeddings模型
|
|
||||||
from server.utils import load_local_embeddings
|
|
||||||
|
|
||||||
embeddings = load_local_embeddings(model=embed_model)
|
|
||||||
return BaseResponse(data=embeddings.embed_documents(texts))
|
|
||||||
|
|
||||||
if embed_model in list_online_embed_models(): # 使用在线API
|
|
||||||
config = get_model_worker_config(embed_model)
|
|
||||||
worker_class = config.get("worker_class")
|
|
||||||
embed_model = config.get("embed_model")
|
|
||||||
worker = worker_class()
|
|
||||||
if worker_class.can_embedding():
|
|
||||||
# params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
|
|
||||||
resp = worker.do_embeddings(None)
|
|
||||||
return BaseResponse(**resp)
|
|
||||||
|
|
||||||
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def aembed_texts(
|
|
||||||
texts: List[str],
|
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
|
||||||
to_query: bool = False,
|
|
||||||
) -> BaseResponse:
|
|
||||||
'''
|
|
||||||
对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
if embed_model in list_embed_models(): # 使用本地Embeddings模型
|
|
||||||
from server.utils import load_local_embeddings
|
|
||||||
|
|
||||||
embeddings = load_local_embeddings(model=embed_model)
|
|
||||||
return BaseResponse(data=await embeddings.aembed_documents(texts))
|
|
||||||
|
|
||||||
if embed_model in list_online_embed_models(): # 使用在线API
|
|
||||||
return await run_in_threadpool(embed_texts,
|
|
||||||
texts=texts,
|
|
||||||
embed_model=embed_model,
|
|
||||||
to_query=to_query)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
|
||||||
|
|
||||||
|
|
||||||
def embed_texts_endpoint(
|
|
||||||
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
|
|
||||||
embed_model: str = Body(EMBEDDING_MODEL,
|
|
||||||
description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
|
|
||||||
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
|
|
||||||
) -> BaseResponse:
|
|
||||||
'''
|
|
||||||
对文本进行向量化,返回 BaseResponse(data=List[List[float]])
|
|
||||||
'''
|
|
||||||
return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
|
|
||||||
|
|
||||||
|
|
||||||
def embed_documents(
|
|
||||||
docs: List[Document],
|
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
|
||||||
to_query: bool = False,
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
|
|
||||||
"""
|
|
||||||
texts = [x.page_content for x in docs]
|
|
||||||
metadatas = [x.metadata for x in docs]
|
|
||||||
embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
|
|
||||||
if embeddings is not None:
|
|
||||||
return {
|
|
||||||
"texts": texts,
|
|
||||||
"embeddings": embeddings,
|
|
||||||
"metadatas": metadatas,
|
|
||||||
}
|
|
||||||
@ -3,7 +3,7 @@ from langchain.vectorstores.faiss import FAISS
|
|||||||
import threading
|
import threading
|
||||||
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
||||||
logger, log_verbose)
|
logger, log_verbose)
|
||||||
from server.utils import embedding_device, get_model_path, list_online_embed_models
|
from server.utils import embedding_device, get_model_path
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Any, Union, Tuple
|
from typing import List, Any, Union, Tuple
|
||||||
@ -98,26 +98,16 @@ class CachePool:
|
|||||||
else:
|
else:
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def load_kb_embeddings(
|
|
||||||
self,
|
|
||||||
kb_name: str,
|
|
||||||
embed_device: str = embedding_device(),
|
|
||||||
default_embed_model: str = EMBEDDING_MODEL,
|
|
||||||
) -> Embeddings:
|
|
||||||
from server.db.repository.knowledge_base_repository import get_kb_detail
|
|
||||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
|
||||||
|
|
||||||
kb_detail = get_kb_detail(kb_name)
|
|
||||||
embed_model = kb_detail.get("embed_model", default_embed_model)
|
|
||||||
|
|
||||||
if embed_model in list_online_embed_models():
|
|
||||||
return EmbeddingsFunAdapter(embed_model)
|
|
||||||
else:
|
|
||||||
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsPool(CachePool):
|
class EmbeddingsPool(CachePool):
|
||||||
|
|
||||||
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
|
def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
|
||||||
|
"""
|
||||||
|
本地Embeddings模型加载
|
||||||
|
:param model:
|
||||||
|
:param device:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
model = model or EMBEDDING_MODEL
|
model = model or EMBEDDING_MODEL
|
||||||
device = embedding_device()
|
device = embedding_device()
|
||||||
@ -127,12 +117,7 @@ class EmbeddingsPool(CachePool):
|
|||||||
self.set(key, item)
|
self.set(key, item)
|
||||||
with item.acquire(msg="初始化"):
|
with item.acquire(msg="初始化"):
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
if 'bge-' in model:
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
||||||
embeddings = OpenAIEmbeddings(model=model,
|
|
||||||
openai_api_key=get_model_path(model),
|
|
||||||
chunk_size=CHUNK_SIZE)
|
|
||||||
elif 'bge-' in model:
|
|
||||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
if 'zh' in model:
|
if 'zh' in model:
|
||||||
# for chinese model
|
# for chinese model
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
|
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
|
||||||
|
from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
|
||||||
from server.knowledge_base.kb_cache.base import *
|
from server.knowledge_base.kb_cache.base import *
|
||||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
# from server.utils import load_local_embeddings
|
||||||
from server.utils import load_local_embeddings
|
|
||||||
from server.knowledge_base.utils import get_vs_path
|
from server.knowledge_base.utils import get_vs_path
|
||||||
from langchain.vectorstores.faiss import FAISS
|
from langchain.vectorstores.faiss import FAISS
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
@ -52,16 +52,40 @@ class ThreadSafeFaiss(ThreadSafeObject):
|
|||||||
class _FaissPool(CachePool):
|
class _FaissPool(CachePool):
|
||||||
def new_vector_store(
|
def new_vector_store(
|
||||||
self,
|
self,
|
||||||
|
kb_name: str,
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
embed_device: str = embedding_device(),
|
embed_device: str = embedding_device(),
|
||||||
) -> FAISS:
|
) -> FAISS:
|
||||||
embeddings = EmbeddingsFunAdapter(embed_model)
|
|
||||||
|
# create an empty vector store
|
||||||
|
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
|
||||||
|
embed_device=embed_device, default_embed_model=embed_model)
|
||||||
doc = Document(page_content="init", metadata={})
|
doc = Document(page_content="init", metadata={})
|
||||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||||
ids = list(vector_store.docstore._dict.keys())
|
ids = list(vector_store.docstore._dict.keys())
|
||||||
vector_store.delete(ids)
|
vector_store.delete(ids)
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
|
def new_temp_vector_store(
        self,
        endpoint_host: str,
        endpoint_host_key: str,
        endpoint_host_proxy: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
) -> FAISS:
    """Create an empty FAISS store whose embeddings use the given endpoint.

    The store is built from a single placeholder document, which is then
    deleted again so the returned store contains no documents.

    :param endpoint_host: endpoint address for the temporary embeddings adapter.
    :param endpoint_host_key: endpoint key for the temporary embeddings adapter.
    :param endpoint_host_proxy: endpoint proxy for the temporary embeddings adapter.
    :param embed_model: embedding model name passed as the adapter's default model.
    :param embed_device: passed through to the adapter loader.
    :return: the emptied FAISS vector store.
    """
    adapter = load_temp_adapter_embeddings(
        endpoint_host=endpoint_host,
        endpoint_host_key=endpoint_host_key,
        endpoint_host_proxy=endpoint_host_proxy,
        embed_device=embed_device,
        default_embed_model=embed_model,
    )
    placeholder = Document(page_content="init", metadata={})
    store = FAISS.from_documents([placeholder], adapter, normalize_L2=True)
    # Drop the placeholder again, leaving an initialised but empty index.
    placeholder_ids = list(store.docstore._dict.keys())
    store.delete(placeholder_ids)
    return store
|
||||||
|
|
||||||
def save_vector_store(self, kb_name: str, path: str = None):
|
def save_vector_store(self, kb_name: str, path: str = None):
|
||||||
if cache := self.get(kb_name):
|
if cache := self.get(kb_name):
|
||||||
return cache.save(path)
|
return cache.save(path)
|
||||||
@ -84,6 +108,7 @@ class KBFaissPool(_FaissPool):
|
|||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
vector_name = vector_name or embed_model
|
vector_name = vector_name or embed_model
|
||||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||||
|
try:
|
||||||
if cache is None:
|
if cache is None:
|
||||||
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
|
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
|
||||||
self.set((kb_name, vector_name), item)
|
self.set((kb_name, vector_name), item)
|
||||||
@ -93,13 +118,15 @@ class KBFaissPool(_FaissPool):
|
|||||||
vs_path = get_vs_path(kb_name, vector_name)
|
vs_path = get_vs_path(kb_name, vector_name)
|
||||||
|
|
||||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
|
embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
|
||||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
embed_device=embed_device, default_embed_model=embed_model)
|
||||||
|
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||||
elif create:
|
elif create:
|
||||||
# create an empty vector store
|
# create an empty vector store
|
||||||
if not os.path.exists(vs_path):
|
if not os.path.exists(vs_path):
|
||||||
os.makedirs(vs_path)
|
os.makedirs(vs_path)
|
||||||
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
|
vector_store = self.new_vector_store(kb_name=kb_name,
|
||||||
|
embed_model=embed_model, embed_device=embed_device)
|
||||||
vector_store.save_local(vs_path)
|
vector_store.save_local(vs_path)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"knowledge base {kb_name} not exist.")
|
raise RuntimeError(f"knowledge base {kb_name} not exist.")
|
||||||
@ -107,13 +134,23 @@ class KBFaissPool(_FaissPool):
|
|||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
|
except Exception as e:
|
||||||
|
self.atomic.release()
|
||||||
|
logger.error(e)
|
||||||
|
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||||
return self.get((kb_name, vector_name))
|
return self.get((kb_name, vector_name))
|
||||||
|
|
||||||
|
|
||||||
class MemoFaissPool(_FaissPool):
|
class MemoFaissPool(_FaissPool):
|
||||||
|
r"""
|
||||||
|
临时向量库的缓存池
|
||||||
|
"""
|
||||||
def load_vector_store(
|
def load_vector_store(
|
||||||
self,
|
self,
|
||||||
kb_name: str,
|
kb_name: str,
|
||||||
|
endpoint_host: str,
|
||||||
|
endpoint_host_key: str,
|
||||||
|
endpoint_host_proxy: str,
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
embed_device: str = embedding_device(),
|
embed_device: str = embedding_device(),
|
||||||
) -> ThreadSafeFaiss:
|
) -> ThreadSafeFaiss:
|
||||||
@ -126,7 +163,10 @@ class MemoFaissPool(_FaissPool):
|
|||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
logger.info(f"loading vector store in '{kb_name}' to memory.")
|
logger.info(f"loading vector store in '{kb_name}' to memory.")
|
||||||
# create an empty vector store
|
# create an empty vector store
|
||||||
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
|
vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
|
||||||
|
endpoint_host_key=endpoint_host_key,
|
||||||
|
endpoint_host_proxy=endpoint_host_proxy,
|
||||||
|
embed_model=embed_model, embed_device=embed_device)
|
||||||
item.obj = vector_store
|
item.obj = vector_store
|
||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
@ -136,40 +176,40 @@ class MemoFaissPool(_FaissPool):
|
|||||||
|
|
||||||
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
||||||
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
|
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
|
||||||
|
#
|
||||||
|
#
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
import time, random
|
# import time, random
|
||||||
from pprint import pprint
|
# from pprint import pprint
|
||||||
|
#
|
||||||
kb_names = ["vs1", "vs2", "vs3"]
|
# kb_names = ["vs1", "vs2", "vs3"]
|
||||||
# for name in kb_names:
|
# # for name in kb_names:
|
||||||
# memo_faiss_pool.load_vector_store(name)
|
# # memo_faiss_pool.load_vector_store(name)
|
||||||
|
#
|
||||||
def worker(vs_name: str, name: str):
|
# def worker(vs_name: str, name: str):
|
||||||
vs_name = "samples"
|
# vs_name = "samples"
|
||||||
time.sleep(random.randint(1, 5))
|
# time.sleep(random.randint(1, 5))
|
||||||
embeddings = load_local_embeddings()
|
# embeddings = load_local_embeddings()
|
||||||
r = random.randint(1, 3)
|
# r = random.randint(1, 3)
|
||||||
|
#
|
||||||
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
|
# with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
|
||||||
if r == 1: # add docs
|
# if r == 1: # add docs
|
||||||
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
|
# ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
|
||||||
pprint(ids)
|
# pprint(ids)
|
||||||
elif r == 2: # search docs
|
# elif r == 2: # search docs
|
||||||
docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
|
# docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
|
||||||
pprint(docs)
|
# pprint(docs)
|
||||||
if r == 3: # delete docs
|
# if r == 3: # delete docs
|
||||||
logger.warning(f"清除 {vs_name} by {name}")
|
# logger.warning(f"清除 {vs_name} by {name}")
|
||||||
kb_faiss_pool.get(vs_name).clear()
|
# kb_faiss_pool.get(vs_name).clear()
|
||||||
|
#
|
||||||
threads = []
|
# threads = []
|
||||||
for n in range(1, 30):
|
# for n in range(1, 30):
|
||||||
t = threading.Thread(target=worker,
|
# t = threading.Thread(target=worker,
|
||||||
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
|
# kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
|
||||||
daemon=True)
|
# daemon=True)
|
||||||
t.start()
|
# t.start()
|
||||||
threads.append(t)
|
# threads.append(t)
|
||||||
|
#
|
||||||
for t in threads:
|
# for t in threads:
|
||||||
t.join()
|
# t.join()
|
||||||
|
|||||||
@ -136,8 +136,7 @@ def upload_docs(
|
|||||||
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
|
||||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||||
docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
|
docs: Json = Form({}, description="自定义的docs,需要转为json字符串"),
|
||||||
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
|
||||||
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
"""
|
"""
|
||||||
@ -238,8 +237,7 @@ def update_docs(
|
|||||||
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||||
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||||
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
|
||||||
docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
|
docs: Json = Body({}, description="自定义的docs,需要转为json字符串"),
|
||||||
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
|
|
||||||
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,10 +1,7 @@
|
|||||||
import operator
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
|
||||||
from langchain.embeddings.base import Embeddings
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
from server.db.repository.knowledge_base_repository import (
|
from server.db.repository.knowledge_base_repository import (
|
||||||
@ -26,20 +23,9 @@ from server.knowledge_base.utils import (
|
|||||||
|
|
||||||
from typing import List, Union, Dict, Optional, Tuple
|
from typing import List, Union, Dict, Optional, Tuple
|
||||||
|
|
||||||
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
|
|
||||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
|
|
||||||
|
|
||||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
|
|
||||||
'''
|
|
||||||
sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn
|
|
||||||
'''
|
|
||||||
norm = np.linalg.norm(embeddings, axis=1)
|
|
||||||
norm = np.reshape(norm, (norm.shape[0], 1))
|
|
||||||
norm = np.tile(norm, (1, len(embeddings[0])))
|
|
||||||
return np.divide(embeddings, norm)
|
|
||||||
|
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
FAISS = 'faiss'
|
FAISS = 'faiss'
|
||||||
MILVUS = 'milvus'
|
MILVUS = 'milvus'
|
||||||
@ -98,12 +84,6 @@ class KBService(ABC):
|
|||||||
status = delete_kb_from_db(self.kb_name)
|
status = delete_kb_from_db(self.kb_name)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
|
|
||||||
'''
|
|
||||||
将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
|
|
||||||
'''
|
|
||||||
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
|
|
||||||
|
|
||||||
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||||
"""
|
"""
|
||||||
向知识库添加文件
|
向知识库添加文件
|
||||||
@ -416,33 +396,6 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsFunAdapter(Embeddings):
|
|
||||||
def __init__(self, embed_model: str = EMBEDDING_MODEL):
|
|
||||||
self.embed_model = embed_model
|
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
|
|
||||||
return normalize(embeddings).tolist()
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
|
||||||
embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
|
|
||||||
query_embed = embeddings[0]
|
|
||||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
|
||||||
normalized_query_embed = normalize(query_embed_2d)
|
|
||||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
|
||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
|
|
||||||
return normalize(embeddings).tolist()
|
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
|
||||||
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
|
|
||||||
query_embed = embeddings[0]
|
|
||||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
|
||||||
normalized_query_embed = normalize(query_embed_2d)
|
|
||||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
|
||||||
|
|
||||||
|
|
||||||
def score_threshold_process(score_threshold, k, docs):
|
def score_threshold_process(score_threshold, k, docs):
|
||||||
if score_threshold is not None:
|
if score_threshold is not None:
|
||||||
cmp = (
|
cmp = (
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from configs import SCORE_THRESHOLD
|
from configs import SCORE_THRESHOLD
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||||
from server.utils import torch_gc
|
from server.utils import torch_gc
|
||||||
@ -55,16 +55,16 @@ class FaissKBService(KBService):
|
|||||||
try:
|
try:
|
||||||
shutil.rmtree(self.kb_path)
|
shutil.rmtree(self.kb_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
...
|
pass
|
||||||
|
|
||||||
def do_search(self,
|
def do_search(self,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Document]:
|
||||||
embed_func = EmbeddingsFunAdapter(self.embed_model)
|
|
||||||
embeddings = embed_func.embed_query(query)
|
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
|
embeddings = vs.embeddings.embed_query(query)
|
||||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -72,12 +72,14 @@ class FaissKBService(KBService):
|
|||||||
docs: List[Document],
|
docs: List[Document],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
data = self._docs_to_embeddings(docs) # 将向量化单独出来可以减少向量库的锁定时间
|
|
||||||
|
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
|
texts = [x.page_content for x in docs]
|
||||||
metadatas=data["metadatas"],
|
metadatas = [x.metadata for x in docs]
|
||||||
ids=kwargs.get("ids"))
|
embeddings = vs.embeddings.embed_documents(texts)
|
||||||
|
|
||||||
|
ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
|
||||||
|
metadatas=metadatas)
|
||||||
if not kwargs.get("not_refresh_vs_cache"):
|
if not kwargs.get("not_refresh_vs_cache"):
|
||||||
vs.save_local(self.vs_path)
|
vs.save_local(self.vs_path)
|
||||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||||
|
|||||||
@ -15,7 +15,7 @@ import langchain.document_loaders
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter
|
from langchain.text_splitter import TextSplitter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from server.utils import run_in_thread_pool, get_model_worker_config
|
from server.utils import run_in_thread_pool
|
||||||
import json
|
import json
|
||||||
from typing import List, Union, Dict, Tuple, Generator
|
from typing import List, Union, Dict, Tuple, Generator
|
||||||
import chardet
|
import chardet
|
||||||
@ -187,8 +187,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
|||||||
def make_text_splitter(
|
def make_text_splitter(
|
||||||
splitter_name,
|
splitter_name,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
chunk_overlap,
|
chunk_overlap
|
||||||
llm_model,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
根据参数获取特定的分词器
|
根据参数获取特定的分词器
|
||||||
@ -223,10 +222,6 @@ def make_text_splitter(
|
|||||||
chunk_overlap=chunk_overlap
|
chunk_overlap=chunk_overlap
|
||||||
)
|
)
|
||||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
|
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
|
||||||
config = get_model_worker_config(llm_model)
|
|
||||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
|
||||||
config.get("model_path")
|
|
||||||
|
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
|
|||||||
@ -88,7 +88,6 @@ def get_OpenAI(
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> OpenAI:
|
) -> OpenAI:
|
||||||
|
|
||||||
# TODO: 从API获取模型信息
|
# TODO: 从API获取模型信息
|
||||||
model = OpenAI(
|
model = OpenAI(
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
@ -319,7 +318,6 @@ def list_embed_models() -> List[str]:
|
|||||||
return list(MODEL_PATH["embed_model"])
|
return list(MODEL_PATH["embed_model"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||||
if type in MODEL_PATH:
|
if type in MODEL_PATH:
|
||||||
paths = MODEL_PATH[type]
|
paths = MODEL_PATH[type]
|
||||||
@ -347,17 +345,6 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
|||||||
return path_str # THUDM/chatglm06b
|
return path_str # THUDM/chatglm06b
|
||||||
|
|
||||||
|
|
||||||
# 从server_config中获取服务信息
|
|
||||||
|
|
||||||
def get_model_worker_config(model_name: str = None) -> dict:
|
|
||||||
'''
|
|
||||||
加载model worker的配置项。
|
|
||||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
|
||||||
'''
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def api_address() -> str:
|
def api_address() -> str:
|
||||||
from configs.server_config import API_SERVER
|
from configs.server_config import API_SERVER
|
||||||
|
|
||||||
@ -559,15 +546,37 @@ def get_server_configs() -> Dict:
|
|||||||
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||||
|
|
||||||
|
|
||||||
def list_online_embed_models() -> List[str]:
|
def list_online_embed_models(
|
||||||
|
endpoint_host: str,
|
||||||
|
endpoint_host_key: str,
|
||||||
|
endpoint_host_proxy: str
|
||||||
|
) -> List[str]:
|
||||||
ret = []
|
ret = []
|
||||||
# TODO: 从在线API获取支持的模型列表
|
# TODO: 从在线API获取支持的模型列表
|
||||||
|
client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {endpoint_host_key}",
|
||||||
|
}
|
||||||
|
resp = client.get("/models", headers=headers)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
models = resp.json().get("data", [])
|
||||||
|
for model in models:
|
||||||
|
if "embedding" in model.get("id", None):
|
||||||
|
ret.append(model.get("id"))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"获取在线Embeddings模型列表失败:{e}"
|
||||||
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
|
exc_info=e if log_verbose else None)
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
||||||
'''
|
'''
|
||||||
从缓存中加载embeddings,可以避免多线程时竞争加载。
|
从缓存中本地Embeddings模型加载,可以避免多线程时竞争加载。
|
||||||
'''
|
'''
|
||||||
from server.knowledge_base.kb_cache.base import embeddings_pool
|
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||||
from configs import EMBEDDING_MODEL
|
from configs import EMBEDDING_MODEL
|
||||||
|
|||||||
@ -1,59 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
root_path = Path(__file__).parent.parent
|
|
||||||
sys.path.append(str(root_path))
|
|
||||||
|
|
||||||
from configs import ONLINE_LLM_MODEL
|
|
||||||
from server.model_workers.base import *
|
|
||||||
from server.utils import get_model_worker_config, list_config_llm_models
|
|
||||||
from pprint import pprint
|
|
||||||
import pytest
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# workers = []
|
|
||||||
# for x in list_config_llm_models()["online"]:
|
|
||||||
# if x in ONLINE_LLM_MODEL and x not in workers:
|
|
||||||
# workers.append(x)
|
|
||||||
# print(f"all workers to test: {workers}")
|
|
||||||
|
|
||||||
# workers = ["fangzhou-api"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("worker", workers)
|
|
||||||
def test_chat(worker):
|
|
||||||
params = ApiChatParams(
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "你是谁"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
print(f"\nchat with {worker} \n")
|
|
||||||
|
|
||||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
|
||||||
for x in worker_class().do_chat(params):
|
|
||||||
pprint(x)
|
|
||||||
assert isinstance(x, dict)
|
|
||||||
assert x["error_code"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("worker", workers)
|
|
||||||
def test_embeddings(worker):
|
|
||||||
params = ApiEmbeddingsParams(
|
|
||||||
texts = [
|
|
||||||
"LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
|
|
||||||
"一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if worker_class := get_model_worker_config(worker).get("worker_class"):
|
|
||||||
if worker_class.can_embedding():
|
|
||||||
print(f"\embeddings with {worker} \n")
|
|
||||||
resp = worker_class().do_embeddings(params)
|
|
||||||
|
|
||||||
pprint(resp, depth=2)
|
|
||||||
assert resp["code"] == 200
|
|
||||||
assert "data" in resp
|
|
||||||
embeddings = resp["data"]
|
|
||||||
assert isinstance(embeddings, list) and len(embeddings) > 0
|
|
||||||
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
|
|
||||||
assert isinstance(embeddings[0][0], float)
|
|
||||||
print("向量长度:", len(embeddings[0]))
|
|
||||||
Loading…
x
Reference in New Issue
Block a user