Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api (#1907)

New features:
- Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add a /other/embed_texts endpoint to the API
- Add an --embed-model parameter to init_database.py to specify the embedding model to use (local or online)

Bug fixes:
- list_config_models in the API stripped sensitive information out of ONLINE_LLM_MODEL in place, which broke the second API request

For developers:
- Clean up the Embeddings handling in kb_service:
  - Unified loading interface: server.utils.load_local_embeddings, backed by a global cache so Embeddings objects no longer have to be passed around
  - Unified text-embedding interface: server.embeddings_api.[embed_texts, embed_documents]
Commit deed92169f (parent aa7c580974)
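For orientation before the diff: the new endpoint can be exercised as below. This is a minimal sketch rather than part of the commit; it assumes the API server is running on the project's default address (127.0.0.1:7861 is an assumption, adjust to your deployment) and that the chosen model is configured with valid credentials.

import requests

resp = requests.post(
    "http://127.0.0.1:7861/other/embed_texts",
    json={
        "texts": ["hello", "world"],
        "embed_model": "zhipu-api",  # any local embedding model, or an online API whose worker can_embedding()
        "to_query": False,           # True for query vectors; Minimax optimizes storage/query vectors differently
    },
)
result = resp.json()                 # BaseResponse shape: {"code": ..., "msg": ..., "data": ...}
assert result["code"] == 200
vectors = result["data"]             # List[List[float]], one vector per input text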
@@ -1,7 +1,7 @@
 import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
-from configs.model_config import NLTK_DATA_PATH
+from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from datetime import datetime
@@ -62,12 +62,20 @@ if __name__ == "__main__":
         )
     )
     parser.add_argument(
+        "-n",
         "--kb-name",
         type=str,
         nargs="+",
         default=[],
         help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
     )
+    parser.add_argument(
+        "-e",
+        "--embed-model",
+        type=str,
+        default=EMBEDDING_MODEL,
+        help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+    )
 
     if len(sys.argv) <= 1:
         parser.print_help()
@@ -80,11 +88,11 @@ if __name__ == "__main__":
         reset_tables()
         print("database talbes reseted")
         print("recreating all vector stores")
-        folder2db(kb_names=args.kb_name, mode="recreate_vs")
+        folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
-        folder2db(kb_names=args.kb_name, mode="update_in_db")
+        folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
     elif args.increament:
-        folder2db(kb_names=args.kb_name, mode="increament")
+        folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
    elif args.prune_db:
         prune_db_docs(args.kb_name)
     elif args.prune_folder:
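A sketch of how the new flag might be used. Only -e/--embed-model and -n/--kb-name are confirmed by this diff; the rebuild flag spelling is inferred from the args attributes above, so check python init_database.py --help:

python init_database.py --recreate-vs -e qwen-api        # rebuild all vector stores with an online embedding model
python init_database.py --recreate-vs -n samples -e m3e-base   # restrict the rebuild to one knowledge base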
@@ -30,7 +30,7 @@ pytest
 # online api libs
 zhipuai
 dashscope>=1.10.0 # qwen
-qianfan
+# qianfan
 # volcengine>=1.0.106 # fangzhou
 
 # uncomment libs if you want to use corresponding vector store
@@ -16,6 +16,7 @@ from server.chat.chat import chat
 from server.chat.openai_chat import openai_chat
 from server.chat.search_engine_chat import search_engine_chat
 from server.chat.completion import completion
+from server.embeddings_api import embed_texts_endpoint
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
@@ -47,33 +48,34 @@ def create_app(run_mode: str = None):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    mount_basic_routes(app)
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_app_routes(app, run_mode=run_mode)
     return app
 
 
-def mount_basic_routes(app: FastAPI):
+def mount_app_routes(app: FastAPI, run_mode: str = None):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)
 
-    app.post("/completion",
-             tags=["Completion"],
-             summary="要求llm模型补全(通过LLMChain)")(completion)
 
     # Tag: Chat
     app.post("/chat/fastchat",
             tags=["Chat"],
-             summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
+             summary="与llm模型对话(直接与fastchat api对话)",
+             )(openai_chat)
 
     app.post("/chat/chat",
             tags=["Chat"],
-             summary="与llm模型对话(通过LLMChain)")(chat)
+             summary="与llm模型对话(通过LLMChain)",
+             )(chat)
 
     app.post("/chat/search_engine_chat",
             tags=["Chat"],
-             summary="与搜索引擎对话")(search_engine_chat)
+             summary="与搜索引擎对话",
+             )(search_engine_chat)
 
+    # 知识库相关接口
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
 
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",
@@ -121,6 +123,17 @@ def mount_basic_routes(app: FastAPI):
             ) -> str:
         return get_prompt_template(type=type, name=name)
 
+    # 其它接口
+    app.post("/other/completion",
+             tags=["Other"],
+             summary="要求llm模型补全(通过LLMChain)",
+             )(completion)
+
+    app.post("/other/embed_texts",
+             tags=["Other"],
+             summary="将文本向量化,支持本地模型和在线模型",
+             )(embed_texts_endpoint)
+
 
 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat
server/embeddings_api.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+from langchain.docstore.document import Document
+from configs import EMBEDDING_MODEL, logger
+from server.model_workers.base import ApiEmbeddingsParams
+from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
+from fastapi import Body
+from typing import Dict, List
+
+
+online_embed_models = list_online_embed_models()
+
+
+def embed_texts(
+    texts: List[str],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> BaseResponse:
+    '''
+    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models(): # 使用本地Embeddings模型
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        if embed_model in list_online_embed_models(): # 使用在线API
+            config = get_model_worker_config(embed_model)
+            worker_class = config.get("worker_class")
+            worker = worker_class()
+            if worker_class.can_embedding():
+                params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
+                resp = worker.do_embeddings(params)
+                return BaseResponse(**resp)
+
+        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+
+def embed_texts_endpoint(
+    texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
+    embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
+    to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
+) -> BaseResponse:
+    '''
+    对文本进行向量化,返回 BaseResponse(data=List[List[float]])
+    '''
+    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
+
+
+def embed_documents(
+    docs: List[Document],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> Dict:
+    """
+    将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
+    """
+    texts = [x.page_content for x in docs]
+    metadatas = [x.metadata for x in docs]
+    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
+    if embeddings is not None:
+        return {
+            "texts": texts,
+            "embeddings": embeddings,
+            "metadatas": metadatas,
+        }
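A minimal sketch of the new unified interface in use (run from the project root so the server.* imports resolve; the model names here are placeholders for whatever you have configured):

from langchain.docstore.document import Document
from server.embeddings_api import embed_texts, embed_documents

# Local and online model names both work; the function dispatches on which list contains the name.
resp = embed_texts(texts=["你好", "hello"], embed_model="m3e-base", to_query=False)
if resp.code == 200:
    vectors = resp.data  # List[List[float]]

# embed_documents() prepares exactly the dict that VectorStore.add_embeddings() expects.
docs = [Document(page_content="hello", metadata={"source": "demo.txt"})]
payload = embed_documents(docs=docs)  # {"texts": [...], "embeddings": [...], "metadatas": [...]}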
@@ -110,7 +110,7 @@ class CachePool:
 
 
 class EmbeddingsPool(CachePool):
-    def load_embeddings(self, model: str, device: str) -> Embeddings:
+    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
         self.atomic.acquire()
         model = model or EMBEDDING_MODEL
         device = device or embedding_device()
@@ -121,7 +121,7 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="初始化"):
             self.atomic.release()
             if model == "text-embedding-ada-002": # openai text-embedding-ada-002
-                embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure
+                embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:
@@ -1,7 +1,10 @@
 from configs import CACHED_VS_NUM
 from server.knowledge_base.kb_cache.base import *
+from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
+from server.utils import load_local_embeddings
 from server.knowledge_base.utils import get_vs_path
-from langchain.vectorstores import FAISS
+from langchain.vectorstores.faiss import FAISS
+from langchain.schema import Document
 import os
 from langchain.schema import Document
 
@@ -38,9 +41,9 @@ class _FaissPool(CachePool):
                 embed_model: str = EMBEDDING_MODEL,
                 embed_device: str = embedding_device(),
                 ) -> FAISS:
-        embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
+        # TODO: 整个Embeddings加载逻辑有些混乱,待清理
 
         # create an empty vector store
+        embeddings = EmbeddingsFunAdapter(embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
         ids = list(vector_store.docstore._dict.keys())
@@ -133,7 +136,7 @@ if __name__ == "__main__":
     def worker(vs_name: str, name: str):
         vs_name = "samples"
         time.sleep(random.randint(1, 5))
-        embeddings = embeddings_pool.load_embeddings()
+        embeddings = load_local_embeddings()
         r = random.randint(1, 3)
 
         with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
|||||||
@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import (
|
|||||||
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
EMBEDDING_MODEL, KB_INFO)
|
EMBEDDING_MODEL, KB_INFO)
|
||||||
from server.knowledge_base.utils import (
|
from server.knowledge_base.utils import (
|
||||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
get_kb_path, get_doc_path, KnowledgeFile,
|
||||||
list_kbs_from_folder, list_files_from_folder,
|
list_kbs_from_folder, list_files_from_folder,
|
||||||
)
|
)
|
||||||
from server.utils import embedding_device
|
|
||||||
from typing import List, Union, Dict, Optional
|
from typing import List, Union, Dict, Optional
|
||||||
|
|
||||||
|
from server.embeddings_api import embed_texts
|
||||||
|
from server.embeddings_api import embed_documents
|
||||||
|
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
FAISS = 'faiss'
|
FAISS = 'faiss'
|
||||||
@ -48,8 +51,6 @@ class KBService(ABC):
|
|||||||
self.kb_path = get_kb_path(self.kb_name)
|
self.kb_path = get_kb_path(self.kb_name)
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
self.do_init()
|
self.do_init()
|
||||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
|
||||||
return load_embeddings(self.embed_model, embed_device)
|
|
||||||
|
|
||||||
def save_vector_store(self):
|
def save_vector_store(self):
|
||||||
'''
|
'''
|
||||||
@ -83,6 +84,12 @@ class KBService(ABC):
|
|||||||
status = delete_kb_from_db(self.kb_name)
|
status = delete_kb_from_db(self.kb_name)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
|
||||||
|
'''
|
||||||
|
将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
|
||||||
|
'''
|
||||||
|
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
|
||||||
|
|
||||||
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||||
"""
|
"""
|
||||||
向知识库添加文件
|
向知识库添加文件
|
||||||
@ -149,8 +156,7 @@ class KBService(ABC):
|
|||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
):
|
):
|
||||||
embeddings = self._load_embeddings()
|
docs = self.do_search(query, top_k, score_threshold)
|
||||||
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||||
@ -346,24 +352,26 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsFunAdapter(Embeddings):
|
class EmbeddingsFunAdapter(Embeddings):
|
||||||
|
def __init__(self, embed_model: str = EMBEDDING_MODEL):
|
||||||
def __init__(self, embeddings: Embeddings):
|
self.embed_model = embed_model
|
||||||
self.embeddings = embeddings
|
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
return normalize(self.embeddings.embed_documents(texts))
|
embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
|
||||||
|
return normalize(embeddings)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
query_embed = self.embeddings.embed_query(text)
|
embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
|
||||||
|
query_embed = embeddings[0]
|
||||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||||
normalized_query_embed = normalize(query_embed_2d)
|
normalized_query_embed = normalize(query_embed_2d)
|
||||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||||
|
|
||||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
# TODO: 暂不支持异步
|
||||||
return await normalize(self.embeddings.aembed_documents(texts))
|
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
# return normalize(await self.embeddings.aembed_documents(texts))
|
||||||
|
|
||||||
async def aembed_query(self, text: str) -> List[float]:
|
# async def aembed_query(self, text: str) -> List[float]:
|
||||||
return await normalize(self.embeddings.aembed_query(text))
|
# return normalize(await self.embeddings.aembed_query(text))
|
||||||
|
|
||||||
|
|
||||||
def score_threshold_process(score_threshold, k, docs):
|
def score_threshold_process(score_threshold, k, docs):
|
||||||
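EmbeddingsFunAdapter now wraps a model name instead of a live Embeddings object, so every vector store routes through the unified embed_texts() whether the model is local or online. A usage sketch (the model name is a placeholder):

from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

adapter = EmbeddingsFunAdapter("qwen-api")                # quacks like a LangChain Embeddings
doc_vecs = adapter.embed_documents(["text a", "text b"])  # embeds with to_query=False, then normalizes
query_vec = adapter.embed_query("a question")             # embeds with to_query=True, returns a normalized 1-D list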
@@ -9,10 +9,9 @@ from configs import (
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
 from server.knowledge_base.utils import KnowledgeFile
-from langchain.embeddings.base import Embeddings
-from typing import List, Dict, Optional
-from langchain.docstore.document import Document
 from server.utils import torch_gc
+from langchain.docstore.document import Document
+from typing import List, Dict, Optional
 
 
 class FaissKBService(KBService):
@@ -58,7 +57,6 @@ class FaissKBService(KBService):
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
-                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
@@ -68,8 +66,11 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    **kwargs,
                    ) -> List[Dict]:
+        data = self._docs_to_embeddings(docs) # 将向量化单独出来可以减少向量库的锁定时间
+
         with self.load_vector_store().acquire() as vs:
-            ids = vs.add_documents(docs)
+            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
+                                    metadatas=data["metadatas"])
             if not kwargs.get("not_refresh_vs_cache"):
                 vs.save_local(self.vs_path)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
@@ -1,8 +1,7 @@
 from typing import List, Dict, Optional
 
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import Milvus
+from langchain.vectorstores.milvus import Milvus
 
 from configs import kbs_config
 
@@ -46,10 +45,8 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS
 
-    def _load_milvus(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
+    def _load_milvus(self):
+        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
 
     def do_init(self):
@@ -60,8 +57,8 @@ class MilvusKBService(KBService):
         self.milvus.col.release()
         self.milvus.col.drop()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_milvus()
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
@@ -1,28 +1,22 @@
 import json
 from typing import List, Dict, Optional
 
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import PGVector
-from langchain.vectorstores.pgvector import DistanceStrategy
+from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
 from sqlalchemy import text
 
 from configs import kbs_config
 
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
-from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-from server.utils import embedding_device as get_embedding_device
+from server.knowledge_base.utils import KnowledgeFile
 
 
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
+    def _load_pg_vector(self):
+        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                   collection_name=self.kb_name,
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))
@@ -57,8 +51,8 @@ class PGKBService(KBService):
             '''))
             connect.commit()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_pg_vector(embeddings=embeddings)
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_pg_vector()
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))
 
@@ -43,11 +43,9 @@ class ZillizKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ZILLIZ
 
-    def _load_zilliz(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
+    def _load_zilliz(self):
         zilliz_args = kbs_config.get("zilliz")
-        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
+        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=zilliz_args)
 
 
@@ -59,8 +57,8 @@ class ZillizKBService(KBService):
         self.zilliz.col.release()
         self.zilliz.col.drop()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_zilliz()
         return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
@@ -1,8 +1,4 @@
 import os
-import sys
-
-sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
-from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
@@ -22,7 +18,6 @@ from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor
 from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -62,14 +57,6 @@ def list_files_from_folder(kb_name: str):
             if os.path.isfile(os.path.join(doc_path, file))]
 
 
-def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
-    '''
-    从缓存中加载embeddings,可以避免多线程时竞争加载。
-    '''
-    from server.knowledge_base.kb_cache.base import embeddings_pool
-    return embeddings_pool.load_embeddings(model=model, device=device)
-
-
 LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],
@@ -239,6 +226,7 @@ def make_text_splitter(
                 from langchain.text_splitter import CharacterTextSplitter
                 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
             else: ## 字符长度加载
+                from transformers import AutoTokenizer
                 tokenizer = AutoTokenizer.from_pretrained(
                     text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                     trust_remote_code=True)
@@ -358,7 +346,6 @@ def files2docs_in_thread(
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     利用多线程批量将磁盘文件转化成langchain Document.
@@ -396,7 +383,7 @@ def files2docs_in_thread(
         except Exception as e:
             yield False, (kb_name, filename, str(e))
 
-    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
         yield result
 
 
@@ -2,6 +2,7 @@ from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
+from copy import deepcopy
 
 
 def list_running_models(
@@ -31,16 +32,16 @@ def list_config_models() -> BaseResponse:
     '''
     从本地获取configs中配置的模型列表
     '''
-    configs = list_config_llm_models()
+    configs = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    for config in configs["online"].values():
-        del_keys = set(["worker_class"])
-        for k in config:
-            if "key" in k.lower() or "secret" in k.lower():
-                del_keys.add(k)
-        for k in del_keys:
-            config.pop(k, None)
+    for name, config in list_config_llm_models()["online"].items():
+        configs[name] = {}
+        for k, v in config.items():
+            if not (k == "worker_class"
+                    or "key" in k.lower()
+                    or "secret" in k.lower()
+                    or k.lower().endswith("id")):
+                configs[name][k] = v
     return BaseResponse(data=configs)
 
 
@@ -51,14 +52,14 @@ def get_model_config(
     '''
    获取LLM模型配置项(合并后的)
     '''
-    config = get_model_worker_config(model_name=model_name)
+    config = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    del_keys = set(["worker_class"])
-    for k in config:
-        if "key" in k.lower() or "secret" in k.lower():
-            del_keys.add(k)
-    for k in del_keys:
-        config.pop(k, None)
+    for k, v in get_model_worker_config(model_name=model_name).items():
+        if not (k == "worker_class"
+                or "key" in k.lower()
+                or "secret" in k.lower()
+                or k.lower().endswith("id")):
+            config[k] = v
 
     return BaseResponse(data=config)
 
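The fix above matters because the old code popped api keys out of the shared config dicts in place, so the first request permanently stripped credentials and the second request failed. The new code copies the harmless entries into a fresh dict instead. A sketch of the filtering (hypothetical key names):

shared = {"api_key": "secret", "secret_key": "s", "version": "chatglm_turbo", "worker_class": object}

# Build a sanitized copy; `shared` keeps its credentials for later requests.
safe = {k: v for k, v in shared.items()
        if not (k == "worker_class" or "key" in k.lower()
                or "secret" in k.lower() or k.lower().endswith("id"))}
# safe == {"version": "chatglm_turbo"}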
@@ -180,7 +180,7 @@ class ApiModelWorker(BaseModelWorker):
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
         '''
         执行Embeddings的方法,默认使用模块里面的embed_documents函数。
-        要求返回形式:{"code": int, "embeddings": List[List[float]], "msg": str}
+        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
         '''
         return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
 
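The vectors key is renamed from "embeddings" to "data" across all workers below, which is what lets embed_texts() pass a worker response straight into BaseResponse(**resp). A sketch of the contract:

# Success: code 200 plus the vectors under "data".
ok = {"code": 200, "data": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}

# Failure: a non-200 code plus a human-readable message.
err = {"code": 500, "msg": "model-x未实现embeddings功能"}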
@@ -99,7 +99,7 @@ class MiniMaxWorker(ApiModelWorker):
         with get_httpx_client() as client:
             r = client.post(url, headers=headers, json=data).json()
             if embeddings := r.get("vectors"):
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             elif error := r.get("base_resp"):
                 return {"code": error["status_code"], "msg": error["status_msg"]}
 
@@ -172,7 +172,7 @@ class QianFanWorker(ApiModelWorker):
             resp = client.post(url, json={"input": params.texts}).json()
             if "error_cdoe" not in resp:
                 embeddings = [x["embedding"] for x in resp.get("data", [])]
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             else:
                 return {"code": resp["error_code"], "msg": resp["error_msg"]}
 
@@ -68,7 +68,7 @@ class QwenWorker(ApiModelWorker):
             return {"code": resp["status_code"], "msg": resp.message}
         else:
             embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
-            return {"code": 200, "embeddings": embeddings}
+            return {"code": 200, "data": embeddings}
 
     def get_embeddings(self, params):
         # TODO: 支持embeddings
@@ -59,7 +59,7 @@ class ChatGLMWorker(ApiModelWorker):
         except Exception as e:
             return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
 
-        return {"code": 200, "embeddings": embeddings}
+        return {"code": 200, "data": embeddings}
 
     def get_embeddings(self, params):
         # TODO: 支持embeddings
@@ -14,7 +14,6 @@ from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
 
-thread_pool = ThreadPoolExecutor(os.cpu_count())
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -368,8 +367,8 @@ def MakeFastAPIOffline(
         redoc_favicon_url=favicon,
     )
 
-# 从model_config中获取模型信息
 
+# 从model_config中获取模型信息
 
 def list_embed_models() -> List[str]:
     '''
@@ -432,8 +431,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
     from server import model_workers
 
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
-    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
+    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
 
 
     if model_name in ONLINE_LLM_MODEL:
@@ -611,21 +610,19 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
 def run_in_thread_pool(
         func: Callable,
         params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     在线程池中批量运行任务,并将运行结果以生成器的形式返回。
     请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
     '''
     tasks = []
-    pool = pool or thread_pool
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
+    with ThreadPoolExecutor() as pool:
+        for kwargs in params:
+            thread = pool.submit(func, **kwargs)
+            tasks.append(thread)
+
+        for obj in as_completed(tasks): # TODO: Ctrl+c无法停止
+            yield obj.result()
 
 
 def get_httpx_client(
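run_in_thread_pool() no longer accepts an external pool or relies on the removed module-level executor; each call opens and closes its own ThreadPoolExecutor. A usage sketch (task functions must take keyword arguments only, per the docstring):

def square(x: int) -> int:
    return x * x

# Results arrive in completion order, not submission order.
for result in run_in_thread_pool(func=square, params=[{"x": i} for i in range(5)]):
    print(result)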
@ -703,7 +700,6 @@ def get_server_configs() -> Dict:
|
|||||||
)
|
)
|
||||||
from configs.model_config import (
|
from configs.model_config import (
|
||||||
LLM_MODEL,
|
LLM_MODEL,
|
||||||
EMBEDDING_MODEL,
|
|
||||||
HISTORY_LEN,
|
HISTORY_LEN,
|
||||||
TEMPERATURE,
|
TEMPERATURE,
|
||||||
)
|
)
|
||||||
@ -728,3 +724,14 @@ def list_online_embed_models() -> List[str]:
|
|||||||
if worker_class is not None and worker_class.can_embedding():
|
if worker_class is not None and worker_class.can_embedding():
|
||||||
ret.append(k)
|
ret.append(k)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def load_local_embeddings(model: str = None, device: str = embedding_device()):
|
||||||
|
'''
|
||||||
|
从缓存中加载embeddings,可以避免多线程时竞争加载。
|
||||||
|
'''
|
||||||
|
from server.knowledge_base.kb_cache.base import embeddings_pool
|
||||||
|
from configs import EMBEDDING_MODEL
|
||||||
|
|
||||||
|
model = model or EMBEDDING_MODEL
|
||||||
|
return embeddings_pool.load_embeddings(model=model, device=device)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ for x in list_config_llm_models()["online"]:
         workers.append(x)
 print(f"all workers to test: {workers}")
 
+# workers = ["qianfan-api"]
+
 
 @pytest.mark.parametrize("worker", workers)
 def test_chat(worker):
@@ -49,8 +51,8 @@ def test_embeddings(worker):
 
     pprint(resp, depth=2)
     assert resp["code"] == 200
-    assert "embeddings" in resp
-    embeddings = resp["embeddings"]
+    assert "data" in resp
+    embeddings = resp["data"]
     assert isinstance(embeddings, list) and len(embeddings) > 0
     assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
     assert isinstance(embeddings[0][0], float)
@@ -9,7 +9,7 @@ from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
                      EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.utils import list_embed_models
+from server.utils import list_embed_models, list_online_embed_models
 import os
 import time
 
@@ -103,7 +103,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )
 
-        embed_models = list_embed_models()
+        embed_models = list_embed_models() + list_online_embed_models()
 
         embed_model = cols[1].selectbox(
             "Embedding 模型",
@@ -853,6 +853,26 @@ class ApiRequest:
         else:
             return ret_sync()
 
+    def embed_texts(
+            self,
+            texts: List[str],
+            embed_model: str = EMBEDDING_MODEL,
+            to_query: bool = False,
+    ) -> List[List[float]]:
+        '''
+        对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型
+        '''
+        data = {
+            "texts": texts,
+            "embed_model": embed_model,
+            "to_query": to_query,
+        }
+        resp = self.post(
+            "/other/embed_texts",
+            json=data,
+        )
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+
 
 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
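A sketch of the new client helper from the Web UI side, assuming a running API server (the model name is a placeholder):

from webui_pages.utils import ApiRequest

api = ApiRequest()  # defaults to the configured api_address()
vectors = api.embed_texts(["hello", "world"], embed_model="zhipu-api")
# vectors is List[List[float]] on success, or None if the request failed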