Support online Embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api (#1907)

New features:
- Support online embeddings: zhipu-api, qwen-api, minimax-api, qianfan-api
- Add an /other/embed_texts endpoint to the API (see the usage sketch below)
- Add an --embed-model parameter to init_database.py to specify the embedding model (either local or online)
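
For instance, once the API server is up, the new endpoint can be exercised with any HTTP client. A minimal sketch; the address and the model name are assumptions, adjust them to your deployment:

import httpx

# assumed defaults: the api server listens on 127.0.0.1:7861 and zhipu-api is configured
resp = httpx.post(
    "http://127.0.0.1:7861/other/embed_texts",
    json={
        "texts": ["hello", "world"],
        "embed_model": "zhipu-api",  # a local embedding model name works here too
        "to_query": False,
    },
    timeout=60,
)
result = resp.json()
if result["code"] == 200:
    vectors = result["data"]  # List[List[float]], one vector per input text
    print(len(vectors), len(vectors[0]))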

Bug fixes:
- list_config_models in the API stripped sensitive keys from ONLINE_LLM_MODEL in place, which broke the second API request (the pitfall is illustrated below)
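
The root cause was sanitizing the shared config dicts in place. A toy illustration of the pitfall and of the copy-based fix pattern; the names here are illustrative, not project code:

# in-place sanitizing mutates the shared dict: the second caller gets no api_key
shared = {"api_key": "secret", "provider": "demo"}

def sanitize_in_place(cfg: dict) -> dict:
    for k in [k for k in cfg if "key" in k.lower()]:
        cfg.pop(k)  # destroys the key for everyone holding this dict
    return cfg

def sanitize_copy(cfg: dict) -> dict:
    # build a new dict containing only the safe entries; the original stays intact
    return {k: v for k, v in cfg.items() if "key" not in k.lower()}

safe = sanitize_copy(shared)
assert "api_key" in shared      # still available for the next request
sanitize_in_place(shared)
assert "api_key" not in shared  # the bug: state lost after the first call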

Developer changes:
- Clean up Embeddings handling in kb_service (see the sketch after this list):
  - Unified loading interface: server.utils.load_local_embeddings, which uses a global cache so Embeddings objects no longer have to be passed around
  - Unified text-embedding interface: server.embeddings_api.[embed_texts, embed_documents]
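
Taken together, a server-side sketch of the unified flow, assuming the project modules are on the import path and an online model such as zhipu-api is configured:

from server.embeddings_api import embed_texts
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

# direct call: the same entry point serves local models and online APIs
resp = embed_texts(texts=["hello"], embed_model="zhipu-api", to_query=False)
if resp.code == 200:
    vectors = resp.data  # List[List[float]]

# adapter: vector stores embed lazily by model name, so no Embeddings
# object needs to be threaded through kb_service any more
adapter = EmbeddingsFunAdapter(embed_model="zhipu-api")
query_vector = adapter.embed_query("hello")
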
Author: liunux4odoo (committed by GitHub), 2023-10-28 23:37:30 +08:00
Parent: aa7c580974, commit: deed92169f
22 changed files with 228 additions and 121 deletions

init_database.py

@@ -1,7 +1,7 @@
 import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
-from configs.model_config import NLTK_DATA_PATH
+from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from datetime import datetime
@@ -62,12 +62,20 @@ if __name__ == "__main__":
         )
     )
     parser.add_argument(
+        "-n",
         "--kb-name",
         type=str,
         nargs="+",
         default=[],
         help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
     )
+    parser.add_argument(
+        "-e",
+        "--embed-model",
+        type=str,
+        default=EMBEDDING_MODEL,
+        help=("specify the embedding model to use. default is the EMBEDDING_MODEL set in model_config.")
+    )

     if len(sys.argv) <= 1:
         parser.print_help()
@@ -80,11 +88,11 @@ if __name__ == "__main__":
         reset_tables()
         print("database talbes reseted")
         print("recreating all vector stores")
-        folder2db(kb_names=args.kb_name, mode="recreate_vs")
+        folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
-        folder2db(kb_names=args.kb_name, mode="update_in_db")
+        folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
     elif args.increament:
-        folder2db(kb_names=args.kb_name, mode="increament")
+        folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
     elif args.prune_db:
         prune_db_docs(args.kb_name)
     elif args.prune_folder:

requirements.txt

@@ -30,7 +30,7 @@ pytest

 # online api libs
 zhipuai
 dashscope>=1.10.0 # qwen
+qianfan # qianfan
 # volcengine>=1.0.106 # fangzhou

 # uncomment libs if you want to use corresponding vector store

server/api.py

@@ -16,6 +16,7 @@ from server.chat.chat import chat
 from server.chat.openai_chat import openai_chat
 from server.chat.search_engine_chat import search_engine_chat
 from server.chat.completion import completion
+from server.embeddings_api import embed_texts_endpoint
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
@@ -47,33 +48,34 @@ def create_app(run_mode: str = None):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    mount_basic_routes(app)
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_app_routes(app, run_mode=run_mode)
     return app


-def mount_basic_routes(app: FastAPI):
+def mount_app_routes(app: FastAPI, run_mode: str = None):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger docs")(document)

-    app.post("/completion",
-             tags=["Completion"],
-             summary="ask the llm to complete text (via LLMChain)")(completion)
-
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
-             summary="chat with the llm (directly via the fastchat api)")(openai_chat)
+             summary="chat with the llm (directly via the fastchat api)",
+             )(openai_chat)

     app.post("/chat/chat",
              tags=["Chat"],
-             summary="chat with the llm (via LLMChain)")(chat)
+             summary="chat with the llm (via LLMChain)",
+             )(chat)

     app.post("/chat/search_engine_chat",
              tags=["Chat"],
-             summary="chat with a search engine")(search_engine_chat)
+             summary="chat with a search engine",
+             )(search_engine_chat)

+    # knowledge base endpoints
+    if run_mode != "lite":
+        mount_knowledge_routes(app)

     # LLM model endpoints
     app.post("/llm_model/list_running_models",
@@ -121,6 +123,17 @@ def mount_basic_routes(app: FastAPI):
     ) -> str:
         return get_prompt_template(type=type, name=name)

+    # other endpoints
+    app.post("/other/completion",
+             tags=["Other"],
+             summary="ask the llm to complete text (via LLMChain)",
+             )(completion)
+
+    app.post("/other/embed_texts",
+             tags=["Other"],
+             summary="vectorize texts; supports both local and online models",
+             )(embed_texts_endpoint)


 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat

server/embeddings_api.py (new file, 68 lines)

@@ -0,0 +1,68 @@
+from langchain.docstore.document import Document
+from configs import EMBEDDING_MODEL, logger
+from server.model_workers.base import ApiEmbeddingsParams
+from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
+from fastapi import Body
+from typing import Dict, List
+
+online_embed_models = list_online_embed_models()
+
+
+def embed_texts(
+    texts: List[str],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> BaseResponse:
+    '''
+    Vectorize texts. Return format: BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models():  # use a local embeddings model
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        if embed_model in list_online_embed_models():  # use an online API
+            config = get_model_worker_config(embed_model)
+            worker_class = config.get("worker_class")
+            worker = worker_class()
+            if worker_class.can_embedding():
+                params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
+                resp = worker.do_embeddings(params)
+                return BaseResponse(**resp)
+
+        return BaseResponse(code=500, msg=f"The specified model {embed_model} does not support embeddings.")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"Error while vectorizing texts: {e}")
+
+
+def embed_texts_endpoint(
+    texts: List[str] = Body(..., description="list of texts to embed", examples=[["hello", "world"]]),
+    embed_model: str = Body(EMBEDDING_MODEL, description=f"embedding model to use; besides locally deployed embedding models, embedding services from online APIs ({online_embed_models}) are also supported."),
+    to_query: bool = Body(False, description="whether the vectors will be used for queries. Some models, e.g. Minimax, optimize stored vectors and query vectors differently."),
+) -> BaseResponse:
+    '''
+    Vectorize texts. Returns BaseResponse(data=List[List[float]])
+    '''
+    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
+
+
+def embed_documents(
+    docs: List[Document],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> Dict:
+    """
+    Vectorize a List[Document] into the arguments accepted by VectorStore.add_embeddings
+    """
+    texts = [x.page_content for x in docs]
+    metadatas = [x.metadata for x in docs]
+    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
+    if embeddings is not None:
+        return {
+            "texts": texts,
+            "embeddings": embeddings,
+            "metadatas": metadatas,
+        }
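
The dict returned by embed_documents plugs straight into VectorStore.add_embeddings. A usage sketch; the model name is an assumption:

from langchain.docstore.document import Document
from server.embeddings_api import embed_documents

docs = [Document(page_content="hello", metadata={"source": "demo.txt"})]
data = embed_documents(docs=docs, embed_model="m3e-base", to_query=False)
# data == {"texts": [...], "embeddings": [...], "metadatas": [...]}; note that
# embed_documents implicitly returns None if the embedding call failed
if data:
    # with a FAISS store held as vs:
    # ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
    #                         metadatas=data["metadatas"])
    pass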

server/knowledge_base/kb_cache/base.py

@@ -110,7 +110,7 @@ class CachePool:

 class EmbeddingsPool(CachePool):
-    def load_embeddings(self, model: str, device: str) -> Embeddings:
+    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
         self.atomic.acquire()
         model = model or EMBEDDING_MODEL
         device = device or embedding_device()
@@ -121,7 +121,7 @@ class EmbeddingsPool(CachePool):
         with item.acquire(msg="initialize"):
             self.atomic.release()
             if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                embeddings = OpenAIEmbeddings(model_name=model,  # TODO: support Azure
+                embeddings = OpenAIEmbeddings(model_name=model,
                                               openai_api_key=get_model_path(model),
                                               chunk_size=CHUNK_SIZE)
             elif 'bge-' in model:

server/knowledge_base/kb_cache/faiss_cache.py

@@ -1,7 +1,10 @@
 from configs import CACHED_VS_NUM
 from server.knowledge_base.kb_cache.base import *
+from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
+from server.utils import load_local_embeddings
 from server.knowledge_base.utils import get_vs_path
-from langchain.vectorstores import FAISS
+from langchain.vectorstores.faiss import FAISS
+from langchain.schema import Document
 import os
 from langchain.schema import Document
@@ -38,9 +41,9 @@ class _FaissPool(CachePool):
                 embed_model: str = EMBEDDING_MODEL,
                 embed_device: str = embedding_device(),
                 ) -> FAISS:
-        embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
+        # TODO: the embeddings-loading logic is somewhat tangled and needs a cleanup
         # create an empty vector store
+        embeddings = EmbeddingsFunAdapter(embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
         ids = list(vector_store.docstore._dict.keys())
@@ -133,7 +136,7 @@ if __name__ == "__main__":
     def worker(vs_name: str, name: str):
         vs_name = "samples"
         time.sleep(random.randint(1, 5))
-        embeddings = embeddings_pool.load_embeddings()
+        embeddings = load_local_embeddings()
         r = random.randint(1, 3)
         with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:

server/knowledge_base/kb_service/base.py

@@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import (
 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
-    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
+    get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
-from server.utils import embedding_device
 from typing import List, Union, Dict, Optional
+from server.embeddings_api import embed_texts
+from server.embeddings_api import embed_documents


 class SupportedVSType:
     FAISS = 'faiss'
@@ -48,8 +51,6 @@ class KBService(ABC):
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
-
-    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
-        return load_embeddings(self.embed_model, embed_device)

     def save_vector_store(self):
         '''
@@ -83,6 +84,12 @@ class KBService(ABC):
         status = delete_kb_from_db(self.kb_name)
         return status

+    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
+        '''
+        Convert a List[Document] into the arguments accepted by VectorStore.add_embeddings
+        '''
+        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
+
     def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         Add a file to the knowledge base
@@ -149,8 +156,7 @@ class KBService(ABC):
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ):
-        embeddings = self._load_embeddings()
-        docs = self.do_search(query, top_k, score_threshold, embeddings)
+        docs = self.do_search(query, top_k, score_threshold)
         return docs

     def get_doc_by_id(self, id: str) -> Optional[Document]:
@@ -346,24 +352,26 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:

 class EmbeddingsFunAdapter(Embeddings):
-    def __init__(self, embeddings: Embeddings):
-        self.embeddings = embeddings
+    def __init__(self, embed_model: str = EMBEDDING_MODEL):
+        self.embed_model = embed_model

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return normalize(self.embeddings.embed_documents(texts))
+        embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
+        return normalize(embeddings)

     def embed_query(self, text: str) -> List[float]:
-        query_embed = self.embeddings.embed_query(text)
+        embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
+        query_embed = embeddings[0]
         query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array into a 2-D array
         normalized_query_embed = normalize(query_embed_2d)
         return normalized_query_embed[0].tolist()  # convert the result back to a 1-D array and return it

-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        return await normalize(self.embeddings.aembed_documents(texts))
-
-    async def aembed_query(self, text: str) -> List[float]:
-        return await normalize(self.embeddings.aembed_query(text))
+    # TODO: async is not supported yet
+    # async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    #     return normalize(await self.embeddings.aembed_documents(texts))
+
+    # async def aembed_query(self, text: str) -> List[float]:
+    #     return normalize(await self.embeddings.aembed_query(text))


 def score_threshold_process(score_threshold, k, docs):
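
EmbeddingsFunAdapter L2-normalizes whatever the backend returns, so scores stay comparable across local and online models (FAISS stores are created with normalize_L2=True). A standalone sketch of the idea behind normalize(), roughly what the project's helper does:

import numpy as np

def normalize(embeddings):
    # scale each row vector to unit L2 norm
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / norms

v = np.array([[3.0, 4.0]])
print(normalize(v))  # [[0.6 0.8]]: inner-product search now equals cosine similarity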

server/knowledge_base/kb_service/faiss_kb_service.py

@@ -9,10 +9,9 @@ from configs import (
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
 from server.knowledge_base.utils import KnowledgeFile
-from langchain.embeddings.base import Embeddings
-from typing import List, Dict, Optional
-from langchain.docstore.document import Document
 from server.utils import torch_gc
+from langchain.docstore.document import Document
+from typing import List, Dict, Optional


 class FaissKBService(KBService):
@@ -58,7 +57,6 @@ class FaissKBService(KBService):
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
-                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
@@ -68,8 +66,11 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    **kwargs,
                    ) -> List[Dict]:
+        data = self._docs_to_embeddings(docs)  # embedding the docs separately shortens the time the vector store stays locked
         with self.load_vector_store().acquire() as vs:
-            ids = vs.add_documents(docs)
+            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
+                                    metadatas=data["metadatas"])
             if not kwargs.get("not_refresh_vs_cache"):
                 vs.save_local(self.vs_path)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]

server/knowledge_base/kb_service/milvus_kb_service.py

@@ -1,8 +1,7 @@
 from typing import List, Dict, Optional
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import Milvus
+from langchain.vectorstores.milvus import Milvus
 from configs import kbs_config
@@ -46,10 +45,8 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS

-    def _load_milvus(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
+    def _load_milvus(self):
+        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

     def do_init(self):
@@ -60,8 +57,8 @@ class MilvusKBService(KBService):
             self.milvus.col.release()
             self.milvus.col.drop()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_milvus()
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

server/knowledge_base/kb_service/pg_kb_service.py

@@ -1,28 +1,22 @@
 import json
 from typing import List, Dict, Optional
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import PGVector
-from langchain.vectorstores.pgvector import DistanceStrategy
+from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
 from sqlalchemy import text

 from configs import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
-from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-from server.utils import embedding_device as get_embedding_device
+from server.knowledge_base.utils import KnowledgeFile


 class PGKBService(KBService):
     pg_vector: PGVector

-    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
+    def _load_pg_vector(self):
+        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                   collection_name=self.kb_name,
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))
@@ -57,8 +51,8 @@ class PGKBService(KBService):
             '''))
             connect.commit()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_pg_vector(embeddings=embeddings)
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_pg_vector()
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))

server/knowledge_base/kb_service/zilliz_kb_service.py

@@ -43,11 +43,9 @@ class ZillizKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ZILLIZ

-    def _load_zilliz(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
+    def _load_zilliz(self):
         zilliz_args = kbs_config.get("zilliz")
-        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
+        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=zilliz_args)
@@ -59,8 +57,8 @@ class ZillizKBService(KBService):
             self.zilliz.col.release()
             self.zilliz.col.drop()

-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_zilliz()
         return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))

     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:

server/knowledge_base/utils.py

@@ -1,8 +1,4 @@
 import os
-import sys
-sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
-from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
@@ -22,7 +18,6 @@ from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor
 from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -62,14 +57,6 @@ def list_files_from_folder(kb_name: str):
             if os.path.isfile(os.path.join(doc_path, file))]

-def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
-    '''
-    Load embeddings from the cache, avoiding contention when multiple threads load at once.
-    '''
-    from server.knowledge_base.kb_cache.base import embeddings_pool
-    return embeddings_pool.load_embeddings(model=model, device=device)

 LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],
@@ -239,6 +226,7 @@ def make_text_splitter(
         from langchain.text_splitter import CharacterTextSplitter
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
     else:  ## split by character length
+        from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
             trust_remote_code=True)
@@ -358,7 +346,6 @@ def files2docs_in_thread(
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     Use multiple threads to batch-convert files on disk into langchain Documents.
@@ -396,7 +383,7 @@ def files2docs_in_thread(
         except Exception as e:
             yield False, (kb_name, filename, str(e))

-    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
         yield result

server/llm_api.py

@@ -2,6 +2,7 @@ from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
+from copy import deepcopy


 def list_running_models(
@@ -31,16 +32,16 @@ def list_config_models() -> BaseResponse:
     '''
     Get the list of models configured in local configs
     '''
-    configs = list_config_llm_models()
+    configs = {}
     # strip sensitive information from the ONLINE_MODEL configs
-    for config in configs["online"].values():
-        del_keys = set(["worker_class"])
-        for k in config:
-            if "key" in k.lower() or "secret" in k.lower():
-                del_keys.add(k)
-        for k in del_keys:
-            config.pop(k, None)
+    for name, config in list_config_llm_models()["online"].items():
+        configs[name] = {}
+        for k, v in config.items():
+            if not (k == "worker_class"
+                or "key" in k.lower()
+                or "secret" in k.lower()
+                or k.lower().endswith("id")):
+                configs[name][k] = v
     return BaseResponse(data=configs)
@@ -51,14 +52,14 @@ def get_model_config(
     '''
     Get the merged configuration of an LLM model
    '''
-    config = get_model_worker_config(model_name=model_name)
+    config = {}
     # strip sensitive information from the ONLINE_MODEL configs
-    del_keys = set(["worker_class"])
-    for k in config:
-        if "key" in k.lower() or "secret" in k.lower():
-            del_keys.add(k)
-    for k in del_keys:
-        config.pop(k, None)
+    for k, v in get_model_worker_config(model_name=model_name).items():
+        if not (k == "worker_class"
+            or "key" in k.lower()
+            or "secret" in k.lower()
+            or k.lower().endswith("id")):
+            config[k] = v

     return BaseResponse(data=config)

server/model_workers/base.py

@@ -180,7 +180,7 @@ class ApiModelWorker(BaseModelWorker):
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
         '''
         Run embeddings; by default this uses the module-level embed_documents function.
-        Expected return shape: {"code": int, "embeddings": List[List[float]], "msg": str}
+        Expected return shape: {"code": int, "data": List[List[float]], "msg": str}
         '''
         return {"code": 500, "msg": f"{self.model_names[0]} does not implement embeddings"}

server/model_workers/minimax.py

@@ -99,7 +99,7 @@ class MiniMaxWorker(ApiModelWorker):
         with get_httpx_client() as client:
             r = client.post(url, headers=headers, json=data).json()
             if embeddings := r.get("vectors"):
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             elif error := r.get("base_resp"):
                 return {"code": error["status_code"], "msg": error["status_msg"]}

server/model_workers/qianfan.py

@@ -172,7 +172,7 @@ class QianFanWorker(ApiModelWorker):
             resp = client.post(url, json={"input": params.texts}).json()
             if "error_cdoe" not in resp:
                 embeddings = [x["embedding"] for x in resp.get("data", [])]
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             else:
                 return {"code": resp["error_code"], "msg": resp["error_msg"]}

server/model_workers/qwen.py

@@ -68,7 +68,7 @@ class QwenWorker(ApiModelWorker):
             return {"code": resp["status_code"], "msg": resp.message}
         else:
             embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
-            return {"code": 200, "embeddings": embeddings}
+            return {"code": 200, "data": embeddings}

     def get_embeddings(self, params):
         # TODO: support embeddings

server/model_workers/zhipu.py

@@ -59,7 +59,7 @@ class ChatGLMWorker(ApiModelWorker):
         except Exception as e:
             return {"code": 500, "msg": f"Error while embedding texts: {e}"}

-        return {"code": 200, "embeddings": embeddings}
+        return {"code": 200, "data": embeddings}

     def get_embeddings(self, params):
         # TODO: support embeddings

server/utils.py

@@ -14,7 +14,6 @@ from langchain.llms import OpenAI, AzureOpenAI, Anthropic
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union

-thread_pool = ThreadPoolExecutor(os.cpu_count())


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -368,8 +367,8 @@ def MakeFastAPIOffline(
         redoc_favicon_url=favicon,
     )

-# get model info from model_config
+# get model info from model_config

 def list_embed_models() -> List[str]:
     '''
@@ -432,8 +431,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
     from server import model_workers

     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
-    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
-    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
+    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())

     if model_name in ONLINE_LLM_MODEL:
@@ -611,21 +610,19 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
 def run_in_thread_pool(
         func: Callable,
         params: List[Dict] = [],
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     Run tasks in batch in a thread pool and yield the results as a generator.
     Make sure all operations in a task are thread-safe, and pass all task arguments as keyword arguments.
     '''
     tasks = []
-    pool = pool or thread_pool
-
-    for kwargs in params:
-        thread = pool.submit(func, **kwargs)
-        tasks.append(thread)
-
-    for obj in as_completed(tasks):
-        yield obj.result()
+    with ThreadPoolExecutor() as pool:
+        for kwargs in params:
+            thread = pool.submit(func, **kwargs)
+            tasks.append(thread)
+
+        for obj in as_completed(tasks):  # TODO: cannot be stopped with Ctrl+C
+            yield obj.result()


 def get_httpx_client(
@@ -703,7 +700,6 @@ def get_server_configs() -> Dict:
     )
     from configs.model_config import (
         LLM_MODEL,
-        EMBEDDING_MODEL,
         HISTORY_LEN,
         TEMPERATURE,
     )
@@ -728,3 +724,14 @@ def list_online_embed_models() -> List[str]:
         if worker_class is not None and worker_class.can_embedding():
             ret.append(k)
     return ret
+
+
+def load_local_embeddings(model: str = None, device: str = embedding_device()):
+    '''
+    Load embeddings from the cache, avoiding contention when multiple threads load at once.
+    '''
+    from server.knowledge_base.kb_cache.base import embeddings_pool
+    from configs import EMBEDDING_MODEL
+
+    model = model or EMBEDDING_MODEL
+    return embeddings_pool.load_embeddings(model=model, device=device)
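
With the pool now created per call and closed by its context manager, callers are unchanged. A self-contained usage sketch of the same pattern:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, Generator, List

def run_in_thread_pool(func: Callable, params: List[Dict] = []) -> Generator:
    # same shape as the project's helper: submit keyword-arg tasks, yield as they finish
    tasks = []
    with ThreadPoolExecutor() as pool:
        for kwargs in params:
            tasks.append(pool.submit(func, **kwargs))
        for obj in as_completed(tasks):
            yield obj.result()

def square(x: int) -> int:
    return x * x

for result in run_in_thread_pool(func=square, params=[{"x": i} for i in range(5)]):
    print(result)  # completion order, not submission order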

tests/online_api/test_model_workers.py

@@ -16,6 +16,8 @@ for x in list_config_llm_models()["online"]:
         workers.append(x)
 print(f"all workers to test: {workers}")

+# workers = ["qianfan-api"]
+

 @pytest.mark.parametrize("worker", workers)
 def test_chat(worker):
@@ -49,8 +51,8 @@ def test_embeddings(worker):
         pprint(resp, depth=2)
         assert resp["code"] == 200
-        assert "embeddings" in resp
-        embeddings = resp["embeddings"]
+        assert "data" in resp
+        embeddings = resp["data"]
         assert isinstance(embeddings, list) and len(embeddings) > 0
         assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
         assert isinstance(embeddings[0][0], float)

webui_pages/knowledge_base/knowledge_base.py

@@ -9,7 +9,7 @@ from typing import Literal, Dict, Tuple
 from configs import (kbs_config,
                      EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
-from server.utils import list_embed_models
+from server.utils import list_embed_models, list_online_embed_models
 import os
 import time
@@ -103,7 +103,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
             key="vs_type",
         )

-        embed_models = list_embed_models()
+        embed_models = list_embed_models() + list_online_embed_models()

         embed_model = cols[1].selectbox(
             "Embedding model",

webui_pages/utils.py

@@ -853,6 +853,26 @@ class ApiRequest:
         else:
             return ret_sync()

+    def embed_texts(
+            self,
+            texts: List[str],
+            embed_model: str = EMBEDDING_MODEL,
+            to_query: bool = False,
+    ) -> List[List[float]]:
+        '''
+        Vectorize texts; available models include local embed_models and online models that support embeddings.
+        '''
+        data = {
+            "texts": texts,
+            "embed_model": embed_model,
+            "to_query": to_query,
+        }
+        resp = self.post(
+            "/other/embed_texts",
+            json=data,
+        )
+        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
+

 class AsyncApiRequest(ApiRequest):
     def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
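
Finally, a quick client-side sketch of the new helper; it assumes a running API server, importable project modules, and a configured online model:

from webui_pages.utils import ApiRequest

api = ApiRequest()  # defaults to the configured api_address()
vectors = api.embed_texts(texts=["hello", "world"], embed_model="zhipu-api")
print(len(vectors), len(vectors[0]))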