diff --git a/.gitignore b/.gitignore
index 05c4b8e2..bba08c3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -180,3 +180,5 @@
 test.py
 configs/*.py
+/knowledge_base/samples/content/202311-D平台项目工作大纲参数,人员中间库表结构说明V1.1(1).docx
+/knowledge_base/samples/content/imi_temeplate.txt
diff --git a/server/api.py b/server/api.py
index d83be2f5..d5477d4d 100644
--- a/server/api.py
+++ b/server/api.py
@@ -16,7 +16,7 @@ from starlette.responses import RedirectResponse
 from server.chat.chat import chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
-from server.embeddings_api import embed_texts_endpoint
+from server.embeddings.core.embeddings_api import embed_texts_endpoint
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index 7af3eed9..2b67a86f 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -1,8 +1,9 @@
 from fastapi import Body, File, Form, UploadFile
 from fastapi.responses import StreamingResponse
 from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.embeddings.adapter import load_temp_adapter_embeddings
 from server.utils import (wrap_done, get_ChatOpenAI,
-                        BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
+                         BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
 from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -10,22 +11,23 @@ from typing import AsyncIterable, List, Optional
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from server.chat.utils import History
-from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
 from server.knowledge_base.utils import KnowledgeFile
 import json
 import os
 
+
 def _parse_files_in_thread(
-    files: List[UploadFile],
-    dir: str,
-    zh_title_enhance: bool,
-    chunk_size: int,
-    chunk_overlap: int,
+        files: List[UploadFile],
+        dir: str,
+        zh_title_enhance: bool,
+        chunk_size: int,
+        chunk_overlap: int,
 ):
     """
     Save the uploaded files to the target directory using multiple threads.
     The generator yields the save results as [success or error, filename, msg, docs].
     """
+
     def parse_file(file: UploadFile) -> dict:
         '''
         Save a single file.
@@ -55,11 +57,14 @@ def _parse_files_in_thread(
 
 
 def upload_temp_docs(
-        files: List[UploadFile] = File(..., description="upload files; multiple files supported"),
-        prev_id: str = Form(None, description="previous temporary knowledge base ID"),
-        chunk_size: int = Form(CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
-        chunk_overlap: int = Form(OVERLAP_SIZE, description="overlap length between adjacent chunks"),
-        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
+        endpoint_host: str = Form(None, description="endpoint host address"),
+        endpoint_host_key: str = Form(None, description="endpoint API key"),
+        endpoint_host_proxy: str = Form(None, description="endpoint proxy address"),
+        files: List[UploadFile] = File(..., description="upload files; multiple files supported"),
+        prev_id: str = Form(None, description="previous temporary knowledge base ID"),
+        chunk_size: int = Form(CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
+        chunk_overlap: int = Form(OVERLAP_SIZE, description="overlap length between adjacent chunks"),
+        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
 ) -> BaseResponse:
     '''
     Save the files to a temporary directory and embed them.
@@ -72,16 +77,20 @@ def upload_temp_docs(
     documents = []
     path, id = get_temp_dir(prev_id)
     for success, file, msg, docs in _parse_files_in_thread(files=files,
-                                                          dir=path,
-                                                          zh_title_enhance=zh_title_enhance,
-                                                          chunk_size=chunk_size,
-                                                          chunk_overlap=chunk_overlap):
+                                                           dir=path,
+                                                           zh_title_enhance=zh_title_enhance,
+                                                           chunk_size=chunk_size,
+                                                           chunk_overlap=chunk_overlap):
         if success:
             documents += docs
         else:
             failed_files.append({file: msg})
 
-    with memo_faiss_pool.load_vector_store(id).acquire() as vs:
+    with memo_faiss_pool.load_vector_store(kb_name=id,
+                                           endpoint_host=endpoint_host,
+                                           endpoint_host_key=endpoint_host_key,
+                                           endpoint_host_proxy=endpoint_host_proxy,
+                                           ).acquire() as vs:
         vs.add_documents(documents)
 
     return BaseResponse(data={"id": id, "failed_files": failed_files})
@@ -89,15 +98,17 @@ def upload_temp_docs(
 async def file_chat(query: str = Body(..., description="user input", examples=["你好"]),
                     knowledge_id: str = Body(..., description="temporary knowledge base ID"),
                     top_k: int = Body(VECTOR_SEARCH_TOP_K, description="number of matched vectors"),
-                    score_threshold: float = Body(SCORE_THRESHOLD, description="knowledge base relevance threshold, in the range 0-1; a smaller SCORE means higher relevance, 1 is equivalent to no filtering; around 0.5 is recommended", ge=0, le=2),
+                    score_threshold: float = Body(SCORE_THRESHOLD,
+                                                  description="knowledge base relevance threshold, in the range 0-1; a smaller SCORE means higher relevance, 1 is equivalent to no filtering; around 0.5 is recommended",
+                                                  ge=0, le=2),
                     history: List[History] = Body([],
-                                              description="conversation history",
-                                              examples=[[
-                                                  {"role": "user",
-                                                   "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                                  {"role": "assistant",
-                                                   "content": "虎头虎脑"}]]
-                                              ),
+                                                  description="conversation history",
+                                                  examples=[[
+                                                      {"role": "user",
+                                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                      {"role": "assistant",
+                                                       "content": "虎头虎脑"}]]
+                                                  ),
                     stream: bool = Body(False, description="streaming output"),
                     endpoint_host: str = Body(None, description="endpoint host address"),
                     endpoint_host_key: str = Body(None, description="endpoint API key"),
@@ -105,8 +116,9 @@ async def file_chat(query: str = Body(..., description="user input", examples=["你好"]),
                     model_name: str = Body(None, description="LLM model name."),
                     temperature: float = Body(0.01, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="cap on the number of tokens the LLM may generate; None means the model's maximum"),
-                    prompt_name: str = Body("default", description="name of the prompt template to use (configured in configs/prompt_config.py)"),
-                    ):
+                    prompt_name: str = Body("default",
+                                            description="name of the prompt template to use (configured in configs/prompt_config.py)"),
+                    ):
     if knowledge_id not in memo_faiss_pool.keys():
         return BaseResponse(code=404, msg=f"Temporary knowledge base {knowledge_id} not found; please upload files first")
 
@@ -127,7 +139,9 @@ async def file_chat(query: str = Body(..., description="user input", examples=["你好"]),
         max_tokens=max_tokens,
         callbacks=[callback],
     )
-    embed_func = EmbeddingsFunAdapter()
+    embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
+                                              endpoint_host_key=endpoint_host_key,
+                                              endpoint_host_proxy=endpoint_host_proxy)
     embeddings = await embed_func.aembed_query(query)
     with memo_faiss_pool.acquire(knowledge_id) as vs:
         docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
@@ -156,7 +170,7 @@ async def file_chat(query: str = Body(..., description="user input", examples=["你好"]),
         text = f"""Source [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
         source_documents.append(text)
 
-    if len(source_documents) == 0: # no relevant documents found
+    if len(source_documents) == 0:  # no relevant documents found
         source_documents.append(f"""No relevant documents found; this answer comes from the model's own knowledge!""")
 
     if stream:
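Reviewer note: a minimal client sketch of the reworked flow above, assuming the API server listens on localhost:7861 and the routes are mounted at /knowledge_base/upload_temp_docs and /chat/file_chat (port and paths are assumptions, not shown in this diff):

    import requests

    # Hypothetical OpenAI-compatible embeddings endpoint.
    endpoint = {
        "endpoint_host": "https://api.example.com/v1",
        "endpoint_host_key": "sk-...",
        "endpoint_host_proxy": "",
    }

    # upload_temp_docs is multipart, so the endpoint settings travel as form fields.
    with open("report.txt", "rb") as f:
        resp = requests.post("http://localhost:7861/knowledge_base/upload_temp_docs",
                             files=[("files", ("report.txt", f, "text/plain"))],
                             data=endpoint)
    knowledge_id = resp.json()["data"]["id"]

    # file_chat takes JSON, so the same settings go in the request body.
    resp = requests.post("http://localhost:7861/chat/file_chat",
                         json={"query": "你好", "knowledge_id": knowledge_id,
                               "stream": False, **endpoint})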
diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py
index f9035af4..e5cd0809 100644
--- a/server/db/models/knowledge_base_model.py
+++ b/server/db/models/knowledge_base_model.py
@@ -12,6 +12,9 @@ class KnowledgeBaseModel(Base):
     kb_name = Column(String(50), comment='knowledge base name')
     kb_info = Column(String(200), comment='knowledge base description (used by Agent)')
     vs_type = Column(String(50), comment='vector store type')
+    endpoint_host = Column(String(50), comment='endpoint host address')
+    endpoint_host_key = Column(String(50), comment='endpoint API key')
+    endpoint_host_proxy = Column(String(50), comment='endpoint proxy address')
     embed_model = Column(String(50), comment='embedding model name')
     file_count = Column(Integer, default=0, comment='file count')
     create_time = Column(DateTime, default=func.now(), comment='creation time')
diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py
index b39c8c57..d6997e4c 100644
--- a/server/db/repository/knowledge_base_repository.py
+++ b/server/db/repository/knowledge_base_repository.py
@@ -48,6 +48,16 @@ def delete_kb_from_db(session, kb_name):
     return True
 
 
+@with_session
+def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy):
+    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
+    if kb:
+        kb.endpoint_host = endpoint_host
+        kb.endpoint_host_key = endpoint_host_key
+        kb.endpoint_host_proxy = endpoint_host_proxy
+    return True
+
+
 @with_session
 def get_kb_detail(session, kb_name: str) -> dict:
     kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
@@ -56,6 +66,9 @@ def get_kb_detail(session, kb_name: str) -> dict:
         "kb_name": kb.kb_name,
         "kb_info": kb.kb_info,
         "vs_type": kb.vs_type,
+        "endpoint_host": kb.endpoint_host,
+        "endpoint_host_key": kb.endpoint_host_key,
+        "endpoint_host_proxy": kb.endpoint_host_proxy,
         "embed_model": kb.embed_model,
         "file_count": kb.file_count,
         "create_time": kb.create_time,
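Reviewer note: the intended round trip through the new repository helper, sketched under the usual @with_session convention where the decorator supplies the session (kb name and endpoint values are illustrative):

    from server.db.repository.knowledge_base_repository import (
        get_kb_detail, update_kb_endpoint_from_db)

    # Point an existing knowledge base at an OpenAI-compatible embeddings endpoint.
    update_kb_endpoint_from_db(kb_name="samples",
                               endpoint_host="https://api.example.com/v1",
                               endpoint_host_key="sk-...",
                               endpoint_host_proxy="")

    # get_kb_detail now surfaces the endpoint fields next to embed_model;
    # this is what load_kb_adapter_embeddings reads back later.
    detail = get_kb_detail("samples")
    assert detail["endpoint_host"] == "https://api.example.com/v1"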
diff --git a/server/embeddings/__init__.py b/server/embeddings/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/embeddings/adapter.py b/server/embeddings/adapter.py
new file mode 100644
index 00000000..50ccdf04
--- /dev/null
+++ b/server/embeddings/adapter.py
@@ -0,0 +1,130 @@
+import numpy as np
+from typing import List
+from langchain.embeddings.base import Embeddings
+from configs import EMBEDDING_MODEL
+from server.embeddings.core.embeddings_api import embed_texts, aembed_texts
+from server.utils import embedding_device
+
+
+class EmbeddingsFunAdapter(Embeddings):
+    _endpoint_host: str
+    _endpoint_host_key: str
+    _endpoint_host_proxy: str
+
+    def __init__(self,
+                 endpoint_host: str,
+                 endpoint_host_key: str,
+                 endpoint_host_proxy: str,
+                 embed_model: str = EMBEDDING_MODEL,
+                 ):
+        self._endpoint_host = endpoint_host
+        self._endpoint_host_key = endpoint_host_key
+        self._endpoint_host_proxy = endpoint_host_proxy
+        self.embed_model = embed_model
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        embeddings = embed_texts(texts=texts,
+                                 endpoint_host=self._endpoint_host,
+                                 endpoint_host_key=self._endpoint_host_key,
+                                 endpoint_host_proxy=self._endpoint_host_proxy,
+                                 embed_model=self.embed_model,
+                                 to_query=False).data
+        return self._normalize(embeddings=embeddings).tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        embeddings = embed_texts(texts=[text],
+                                 endpoint_host=self._endpoint_host,
+                                 endpoint_host_key=self._endpoint_host_key,
+                                 endpoint_host_proxy=self._endpoint_host_proxy,
+                                 embed_model=self.embed_model,
+                                 to_query=True).data
+        query_embed = embeddings[0]
+        query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array to 2-D
+        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
+        return normalized_query_embed[0].tolist()  # flatten back to 1-D before returning
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        embeddings = (await aembed_texts(texts=texts,
+                                         endpoint_host=self._endpoint_host,
+                                         endpoint_host_key=self._endpoint_host_key,
+                                         endpoint_host_proxy=self._endpoint_host_proxy,
+                                         embed_model=self.embed_model,
+                                         to_query=False)).data
+        return self._normalize(embeddings=embeddings).tolist()
+
+    async def aembed_query(self, text: str) -> List[float]:
+        embeddings = (await aembed_texts(texts=[text],
+                                         endpoint_host=self._endpoint_host,
+                                         endpoint_host_key=self._endpoint_host_key,
+                                         endpoint_host_proxy=self._endpoint_host_proxy,
+                                         embed_model=self.embed_model,
+                                         to_query=True)).data
+        query_embed = embeddings[0]
+        query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array to 2-D
+        normalized_query_embed = self._normalize(embeddings=query_embed_2d)
+        return normalized_query_embed[0].tolist()  # flatten back to 1-D before returning
+
+    @staticmethod
+    def _normalize(embeddings: List[List[float]]) -> np.ndarray:
+        '''
+        Drop-in replacement for sklearn.preprocessing.normalize (L2) that avoids
+        installing scipy and scikit-learn.
+        # TODO: the handling here is flagged as incorrect and needs review
+        '''
+        norm = np.linalg.norm(embeddings, axis=1)
+        norm = np.reshape(norm, (norm.shape[0], 1))
+        norm = np.tile(norm, (1, len(embeddings[0])))
+        return np.divide(embeddings, norm)
+
+
+def load_kb_adapter_embeddings(
+        kb_name: str,
+        embed_device: str = embedding_device(),
+        default_embed_model: str = EMBEDDING_MODEL,
+) -> "EmbeddingsFunAdapter":
+    """
+    Load the embeddings model configured for a knowledge base.
+    Local models are ultimately loaded through load_embeddings;
+    online models are served directly by the adapter.
+    :param kb_name:
+    :param embed_device:
+    :param default_embed_model:
+    :return:
+    """
+    from server.db.repository.knowledge_base_repository import get_kb_detail
+
+    kb_detail = get_kb_detail(kb_name)
+    embed_model = kb_detail.get("embed_model", default_embed_model)
+    endpoint_host = kb_detail.get("endpoint_host", None)
+    endpoint_host_key = kb_detail.get("endpoint_host_key", None)
+    endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None)
+
+    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
+                                endpoint_host_key=endpoint_host_key,
+                                endpoint_host_proxy=endpoint_host_proxy,
+                                embed_model=embed_model)
+
+
+def load_temp_adapter_embeddings(
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str,
+        embed_device: str = embedding_device(),
+        default_embed_model: str = EMBEDDING_MODEL,
+) -> "EmbeddingsFunAdapter":
+    """
+    Load an embeddings adapter for a temporary knowledge base.
+    Local models are ultimately loaded through load_embeddings;
+    online models are served directly by the adapter.
+    :param endpoint_host:
+    :param endpoint_host_key:
+    :param endpoint_host_proxy:
+    :param embed_device:
+    :param default_embed_model:
+    :return:
+    """
+
+    return EmbeddingsFunAdapter(endpoint_host=endpoint_host,
+                                endpoint_host_key=endpoint_host_key,
+                                endpoint_host_proxy=endpoint_host_proxy,
+                                embed_model=default_embed_model)
diff --git a/server/embeddings/core/__init__.py b/server/embeddings/core/__init__.py
new file mode 100644
index 00000000..e69de29b
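Reviewer note: a minimal sketch of the adapter in isolation (the endpoint values are placeholders; passing None for all three falls through to the local-model path inside embed_texts):

    from server.embeddings.adapter import EmbeddingsFunAdapter

    adapter = EmbeddingsFunAdapter(endpoint_host="https://api.example.com/v1",
                                   endpoint_host_key="sk-...",
                                   endpoint_host_proxy=None,
                                   embed_model="text-embedding-ada-002")
    vec = adapter.embed_query("What does the uploaded report cover?")

    # _normalize L2-normalizes every vector, so the FAISS inner-product scores
    # used downstream behave like cosine similarity.
    assert abs(sum(v * v for v in vec) - 1.0) < 1e-6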
diff --git a/server/embeddings/core/embeddings_api.py b/server/embeddings/core/embeddings_api.py
new file mode 100644
index 00000000..9f81408c
--- /dev/null
+++ b/server/embeddings/core/embeddings_api.py
@@ -0,0 +1,94 @@
+from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE
+from server.utils import BaseResponse, list_embed_models, list_online_embed_models
+from fastapi import Body
+from fastapi.concurrency import run_in_threadpool
+from typing import List
+
+
+def embed_texts(
+        texts: List[str],
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str,
+        embed_model: str = EMBEDDING_MODEL,
+        to_query: bool = False,
+) -> BaseResponse:
+    '''
+    Embed a list of texts. Returned format: BaseResponse(data=List[List[float]])
+    TODO: a cache may be worth adding here to cut token consumption
+    '''
+    try:
+        if embed_model in list_embed_models():  # use a local embeddings model
+            from server.utils import load_local_embeddings
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        # use an online API
+        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
+                                                   endpoint_host_key=endpoint_host_key,
+                                                   endpoint_host_proxy=endpoint_host_proxy):
+            from langchain.embeddings.openai import OpenAIEmbeddings
+            embeddings = OpenAIEmbeddings(model=embed_model,
+                                          openai_api_key=endpoint_host_key if endpoint_host_key else "None",
+                                          openai_api_base=endpoint_host if endpoint_host else "None",
+                                          openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None,
+                                          chunk_size=CHUNK_SIZE)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        return BaseResponse(code=500, msg=f"The specified model {embed_model} does not support embeddings.")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"Error while embedding texts: {e}")
+
+
+async def aembed_texts(
+        texts: List[str],
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str,
+        embed_model: str = EMBEDDING_MODEL,
+        to_query: bool = False,
+) -> BaseResponse:
+    '''
+    Embed a list of texts. Returned format: BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models():  # use a local embeddings model
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=await embeddings.aembed_documents(texts))
+
+        # use an online API
+        if embed_model in list_online_embed_models(endpoint_host=endpoint_host,
+                                                   endpoint_host_key=endpoint_host_key,
+                                                   endpoint_host_proxy=endpoint_host_proxy):
+            return await run_in_threadpool(embed_texts,
+                                           texts=texts,
+                                           endpoint_host=endpoint_host,
+                                           endpoint_host_key=endpoint_host_key,
+                                           endpoint_host_proxy=endpoint_host_proxy,
+                                           embed_model=embed_model,
+                                           to_query=to_query)
+
+        return BaseResponse(code=500, msg=f"The specified model {embed_model} does not support embeddings.")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"Error while embedding texts: {e}")
+
+
+def embed_texts_endpoint(
+        texts: List[str] = Body(..., description="list of texts to embed", examples=[["hello", "world"]]),
+        endpoint_host: str = Body(None, description="endpoint host address"),
+        endpoint_host_key: str = Body(None, description="endpoint API key"),
+        endpoint_host_proxy: str = Body(None, description="endpoint proxy address"),
+        embed_model: str = Body(EMBEDDING_MODEL, description="embedding model to use"),
+        to_query: bool = Body(False, description="whether the vectors are for a query; some models such as Minimax optimize stored vs. query vectors differently."),
+) -> BaseResponse:
+    '''
+    API entry point: embed texts and return BaseResponse(data=List[List[float]])
+    '''
+    return embed_texts(texts=texts,
+                       endpoint_host=endpoint_host,
+                       endpoint_host_key=endpoint_host_key,
+                       endpoint_host_proxy=endpoint_host_proxy,
+                       embed_model=embed_model, to_query=to_query)
diff --git a/server/embeddings_api.py b/server/embeddings_api.py
deleted file mode 100644
index d86189fb..00000000
--- a/server/embeddings_api.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from langchain.docstore.document import Document
-from configs import EMBEDDING_MODEL, logger
-# from server.model_workers.base import ApiEmbeddingsParams
-from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
-from fastapi import Body
-from fastapi.concurrency import run_in_threadpool
-from typing import Dict, List
-
-online_embed_models = list_online_embed_models()
-
-
-def embed_texts(
-        texts: List[str],
-        embed_model: str = EMBEDDING_MODEL,
-        to_query: bool = False,
-) -> BaseResponse:
-    '''
-    Embed a list of texts. Returned format: BaseResponse(data=List[List[float]])
-    '''
-    try:
-        if embed_model in list_embed_models():  # use a local embeddings model
-            from server.utils import load_local_embeddings
-
-            embeddings = load_local_embeddings(model=embed_model)
-            return BaseResponse(data=embeddings.embed_documents(texts))
-
-        if embed_model in list_online_embed_models():  # use an online API
-            config = get_model_worker_config(embed_model)
-            worker_class = config.get("worker_class")
-            embed_model = config.get("embed_model")
-            worker = worker_class()
-            if worker_class.can_embedding():
-                # params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
-                resp = worker.do_embeddings(None)
-                return BaseResponse(**resp)
-
-        return BaseResponse(code=500, msg=f"The specified model {embed_model} does not support embeddings.")
-    except Exception as e:
-        logger.error(e)
-        return BaseResponse(code=500, msg=f"Error while embedding texts: {e}")
-
-
-async def aembed_texts(
-        texts: List[str],
-        embed_model: str = EMBEDDING_MODEL,
-        to_query: bool = False,
-) -> BaseResponse:
-    '''
-    Embed a list of texts. Returned format: BaseResponse(data=List[List[float]])
-    '''
-    try:
-        if embed_model in list_embed_models():  # use a local embeddings model
-            from server.utils import load_local_embeddings
-
-            embeddings = load_local_embeddings(model=embed_model)
-            return BaseResponse(data=await embeddings.aembed_documents(texts))
-
-        if embed_model in list_online_embed_models():  # use an online API
-            return await run_in_threadpool(embed_texts,
-                                           texts=texts,
-                                           embed_model=embed_model,
-                                           to_query=to_query)
-    except Exception as e:
-        logger.error(e)
-        return BaseResponse(code=500, msg=f"Error while embedding texts: {e}")
-
-
-def embed_texts_endpoint(
-        texts: List[str] = Body(..., description="list of texts to embed", examples=[["hello", "world"]]),
-        embed_model: str = Body(EMBEDDING_MODEL,
-                                description=f"embedding model to use; besides locally deployed embedding models, embedding services from online APIs ({online_embed_models}) are also supported."),
-        to_query: bool = Body(False, description="whether the vectors are for a query; some models such as Minimax optimize stored vs. query vectors differently."),
-) -> BaseResponse:
-    '''
-    Embed texts and return BaseResponse(data=List[List[float]])
-    '''
-    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
-
-
-def embed_documents(
-        docs: List[Document],
-        embed_model: str = EMBEDDING_MODEL,
-        to_query: bool = False,
-) -> Dict:
-    """
-    Embed a List[Document] into the arguments accepted by VectorStore.add_embeddings
-    """
-    texts = [x.page_content for x in docs]
-    metadatas = [x.metadata for x in docs]
-    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
-    if embeddings is not None:
-        return {
-            "texts": texts,
-            "embeddings": embeddings,
-            "metadatas": metadatas,
-        }
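Reviewer note: callers of the relocated module deal only in BaseResponse, so error handling stays uniform across the local and online paths. A usage sketch (endpoint values are placeholders; assumes BaseResponse keeps its usual code/msg/data fields):

    from server.embeddings.core.embeddings_api import embed_texts

    resp = embed_texts(texts=["hello", "world"],
                       endpoint_host="https://api.example.com/v1",
                       endpoint_host_key="sk-...",
                       endpoint_host_proxy=None,
                       embed_model="text-embedding-ada-002")
    if resp.code == 200:
        vectors = resp.data  # List[List[float]], one vector per input text
    else:
        raise RuntimeError(resp.msg)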
diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py
index dc3291d4..96559b86 100644
--- a/server/knowledge_base/kb_cache/base.py
+++ b/server/knowledge_base/kb_cache/base.py
@@ -3,7 +3,7 @@ from langchain.vectorstores.faiss import FAISS
 import threading
 from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                      logger, log_verbose)
-from server.utils import embedding_device, get_model_path, list_online_embed_models
+from server.utils import embedding_device, get_model_path
 from contextlib import contextmanager
 from collections import OrderedDict
 from typing import List, Any, Union, Tuple
@@ -98,26 +98,16 @@ class CachePool:
         else:
             return cache
 
-    def load_kb_embeddings(
-            self,
-            kb_name: str,
-            embed_device: str = embedding_device(),
-            default_embed_model: str = EMBEDDING_MODEL,
-    ) -> Embeddings:
-        from server.db.repository.knowledge_base_repository import get_kb_detail
-        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
-
-        kb_detail = get_kb_detail(kb_name)
-        embed_model = kb_detail.get("embed_model", default_embed_model)
-
-        if embed_model in list_online_embed_models():
-            return EmbeddingsFunAdapter(embed_model)
-        else:
-            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
-
 
 class EmbeddingsPool(CachePool):
+
     def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
+        """
+        Load a local embeddings model.
+        :param model:
+        :param device:
+        :return:
+        """
         self.atomic.acquire()
         model = model or EMBEDDING_MODEL
         device = embedding_device()
@@ -127,12 +117,7 @@ class EmbeddingsPool(CachePool):
             self.set(key, item)
             with item.acquire(msg="initializing"):
                 self.atomic.release()
-                if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
-                    from langchain.embeddings.openai import OpenAIEmbeddings
-                    embeddings = OpenAIEmbeddings(model=model,
-                                                  openai_api_key=get_model_path(model),
-                                                  chunk_size=CHUNK_SIZE)
-                elif 'bge-' in model:
+                if 'bge-' in model:
                     from langchain.embeddings import HuggingFaceBgeEmbeddings
                     if 'zh' in model:
                         # for chinese model
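Reviewer note: load_embeddings above keeps the pool's double-checked loading discipline: a short-lived global lock registers the cache entry, then a per-item lock makes one caller perform the load while the rest wait. Stripped of the project classes, the pattern is roughly:

    import threading

    _cache, _atomic = {}, threading.Lock()

    def get_or_load(key, loader):
        with _atomic:  # cheap: only registers the slot
            item = _cache.setdefault(key, {"lock": threading.Lock(), "obj": None})
        with item["lock"]:  # the expensive load happens exactly once
            if item["obj"] is None:
                item["obj"] = loader()
        return item["obj"]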
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index 60c550ee..86d3348e 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -1,7 +1,7 @@
 from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
+from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings
 from server.knowledge_base.kb_cache.base import *
-from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
-from server.utils import load_local_embeddings
+# from server.utils import load_local_embeddings
 from server.knowledge_base.utils import get_vs_path
 from langchain.vectorstores.faiss import FAISS
 from langchain.docstore.in_memory import InMemoryDocstore
@@ -51,18 +51,42 @@ class ThreadSafeFaiss(ThreadSafeObject):
 
 class _FaissPool(CachePool):
     def new_vector_store(
-        self,
-        embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = embedding_device(),
+            self,
+            kb_name: str,
+            embed_model: str = EMBEDDING_MODEL,
+            embed_device: str = embedding_device(),
     ) -> FAISS:
-        embeddings = EmbeddingsFunAdapter(embed_model)
+
+        # create an empty vector store
+        embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
+                                                embed_device=embed_device,
+                                                default_embed_model=embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
         ids = list(vector_store.docstore._dict.keys())
         vector_store.delete(ids)
         return vector_store
 
-    def save_vector_store(self, kb_name: str, path: str=None):
+    def new_temp_vector_store(
+            self,
+            endpoint_host: str,
+            endpoint_host_key: str,
+            endpoint_host_proxy: str,
+            embed_model: str = EMBEDDING_MODEL,
+            embed_device: str = embedding_device(),
+    ) -> FAISS:
+
+        # create an empty vector store
+        embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host,
+                                                  endpoint_host_key=endpoint_host_key,
+                                                  endpoint_host_proxy=endpoint_host_proxy,
+                                                  embed_device=embed_device,
+                                                  default_embed_model=embed_model)
+        doc = Document(page_content="init", metadata={})
+        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
+        ids = list(vector_store.docstore._dict.keys())
+        vector_store.delete(ids)
+        return vector_store
+
+    def save_vector_store(self, kb_name: str, path: str = None):
         if cache := self.get(kb_name):
             return cache.save(path)
 
@@ -83,39 +107,52 @@ class KBFaissPool(_FaissPool):
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         vector_name = vector_name or embed_model
-        cache = self.get((kb_name, vector_name))  # a tuple key works better than string concatenation
-        if cache is None:
-            item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
-            self.set((kb_name, vector_name), item)
-            with item.acquire(msg="initializing"):
-                self.atomic.release()
-                logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
-                vs_path = get_vs_path(kb_name, vector_name)
-
-                if os.path.isfile(os.path.join(vs_path, "index.faiss")):
-                    embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
-                    vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
-                elif create:
-                    # create an empty vector store
-                    if not os.path.exists(vs_path):
-                        os.makedirs(vs_path)
-                    vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
-                    vector_store.save_local(vs_path)
-                else:
-                    raise RuntimeError(f"knowledge base {kb_name} not exist.")
-                item.obj = vector_store
-                item.finish_loading()
-        else:
+        cache = self.get((kb_name, vector_name))  # a tuple key works better than string concatenation
+        try:
+            if cache is None:
+                item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
+                self.set((kb_name, vector_name), item)
+                with item.acquire(msg="initializing"):
+                    self.atomic.release()
+                    logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
+                    vs_path = get_vs_path(kb_name, vector_name)
+
+                    if os.path.isfile(os.path.join(vs_path, "index.faiss")):
+                        embeddings = load_kb_adapter_embeddings(kb_name=kb_name,
+                                                                embed_device=embed_device,
+                                                                default_embed_model=embed_model)
+                        vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+                    elif create:
+                        # create an empty vector store
+                        if not os.path.exists(vs_path):
+                            os.makedirs(vs_path)
+                        vector_store = self.new_vector_store(kb_name=kb_name,
+                                                             embed_model=embed_model,
+                                                             embed_device=embed_device)
+                        vector_store.save_local(vs_path)
+                    else:
+                        raise RuntimeError(f"knowledge base {kb_name} not exist.")
+                    item.obj = vector_store
+                    item.finish_loading()
+            else:
+                self.atomic.release()
+        except Exception as e:
             self.atomic.release()
+            logger.error(e)
+            raise RuntimeError(f"Failed to load vector store {kb_name}.")
         return self.get((kb_name, vector_name))
 
 
 class MemoFaissPool(_FaissPool):
+    r"""
+    Cache pool for temporary vector stores.
+    """
     def load_vector_store(
-        self,
-        kb_name: str,
-        embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = embedding_device(),
+            self,
+            kb_name: str,
+            endpoint_host: str,
+            endpoint_host_key: str,
+            endpoint_host_proxy: str,
+            embed_model: str = EMBEDDING_MODEL,
+            embed_device: str = embedding_device(),
     ) -> ThreadSafeFaiss:
         self.atomic.acquire()
         cache = self.get(kb_name)
@@ -126,7 +163,10 @@ class MemoFaissPool(_FaissPool):
                 self.atomic.release()
                 logger.info(f"loading vector store in '{kb_name}' to memory.")
                 # create an empty vector store
-                vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
+                vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host,
+                                                          endpoint_host_key=endpoint_host_key,
+                                                          endpoint_host_proxy=endpoint_host_proxy,
+                                                          embed_model=embed_model,
+                                                          embed_device=embed_device)
                 item.obj = vector_store
                 item.finish_loading()
             else:
@@ -136,40 +176,40 @@ class MemoFaissPool(_FaissPool):
 
 kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
 memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
-
-
-if __name__ == "__main__":
-    import time, random
-    from pprint import pprint
-
-    kb_names = ["vs1", "vs2", "vs3"]
-    # for name in kb_names:
-    #     memo_faiss_pool.load_vector_store(name)
-
-    def worker(vs_name: str, name: str):
-        vs_name = "samples"
-        time.sleep(random.randint(1, 5))
-        embeddings = load_local_embeddings()
-        r = random.randint(1, 3)
-
-        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
-            if r == 1:  # add docs
-                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
-                pprint(ids)
-            elif r == 2:  # search docs
-                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
-                pprint(docs)
-            if r == 3:  # delete docs
-                logger.warning(f"clearing {vs_name} by {name}")
-                kb_faiss_pool.get(vs_name).clear()
-
-    threads = []
-    for n in range(1, 30):
-        t = threading.Thread(target=worker,
-                             kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
-                             daemon=True)
-        t.start()
-        threads.append(t)
-
-    for t in threads:
-        t.join()
+#
+#
+# if __name__ == "__main__":
+#     import time, random
+#     from pprint import pprint
+#
+#     kb_names = ["vs1", "vs2", "vs3"]
+#     # for name in kb_names:
+#     #     memo_faiss_pool.load_vector_store(name)
+#
+#     def worker(vs_name: str, name: str):
+#         vs_name = "samples"
+#         time.sleep(random.randint(1, 5))
+#         embeddings = load_local_embeddings()
+#         r = random.randint(1, 3)
+#
+#         with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
+#             if r == 1:  # add docs
+#                 ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
+#                 pprint(ids)
+#             elif r == 2:  # search docs
+#                 docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
+#                 pprint(docs)
+#             if r == 3:  # delete docs
+#                 logger.warning(f"clearing {vs_name} by {name}")
+#                 kb_faiss_pool.get(vs_name).clear()
+#
+#     threads = []
+#     for n in range(1, 30):
+#         t = threading.Thread(target=worker,
+#                              kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
+#                              daemon=True)
+#         t.start()
+#         threads.append(t)
+#
+#     for t in threads:
+#         t.join()
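Reviewer note: the temporary-store lifecycle these pool changes enable, sketched end to end (the id and endpoint values are illustrative):

    from langchain.docstore.document import Document
    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool

    # First touch builds an empty in-memory FAISS store bound to the endpoint.
    with memo_faiss_pool.load_vector_store(kb_name="temp-42",
                                           endpoint_host="https://api.example.com/v1",
                                           endpoint_host_key="sk-...",
                                           endpoint_host_proxy=None,
                                           ).acquire() as vs:
        vs.add_documents([Document(page_content="hello", metadata={})])

    # Later callers hit the cache and search without re-supplying the endpoint,
    # because the adapter rides along on the store as vs.embeddings.
    with memo_faiss_pool.acquire("temp-42") as vs:
        hits = vs.similarity_search("hello", k=1)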
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index e58ea41f..42799ef7 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -136,8 +136,7 @@ def upload_docs(
         chunk_size: int = Form(CHUNK_SIZE, description="maximum length of a single text chunk in the knowledge base"),
         chunk_overlap: int = Form(OVERLAP_SIZE, description="overlap length between adjacent chunks"),
         zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
-        docs: Json = Form({}, description="custom docs, passed as a JSON string",
-                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+        docs: Json = Form({}, description="custom docs, passed as a JSON string"),
         not_refresh_vs_cache: bool = Form(False, description="do not persist the vector store yet (FAISS only)"),
 ) -> BaseResponse:
     """
@@ -238,8 +237,7 @@ def update_docs(
         chunk_overlap: int = Body(OVERLAP_SIZE, description="overlap length between adjacent chunks"),
         zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="whether to enable Chinese title enhancement"),
         override_custom_docs: bool = Body(False, description="whether to override previously customized docs"),
-        docs: Json = Body({}, description="custom docs, passed as a JSON string",
-                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+        docs: Json = Body({}, description="custom docs, passed as a JSON string"),
         not_refresh_vs_cache: bool = Body(False, description="do not persist the vector store yet (FAISS only)"),
 ) -> BaseResponse:
     """
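Reviewer note: the dropped examples= embedded a Document object, which is not JSON-serializable (presumably why it was removed); the parameter itself still expects the same shape. What a client would actually send (file name and text are made up):

    import json

    # The docs field is a JSON string mapping file names to document dicts.
    docs_payload = json.dumps({
        "test.txt": [
            {"page_content": "custom doc", "metadata": {}},
        ],
    })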
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index a6a42598..8fb6d148 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -1,10 +1,7 @@
-import operator
 from abc import ABC, abstractmethod
 import os
 from pathlib import Path
 
-import numpy as np
-from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
 
 from server.db.repository.knowledge_base_repository import (
@@ -26,20 +23,9 @@ from server.knowledge_base.utils import (
 
 from typing import List, Union, Dict, Optional, Tuple
 
-from server.embeddings_api import embed_texts, aembed_texts, embed_documents
 from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 
 
-def normalize(embeddings: List[List[float]]) -> np.ndarray:
-    '''
-    Drop-in replacement for sklearn.preprocessing.normalize (L2) that avoids
-    installing scipy and scikit-learn.
-    '''
-    norm = np.linalg.norm(embeddings, axis=1)
-    norm = np.reshape(norm, (norm.shape[0], 1))
-    norm = np.tile(norm, (1, len(embeddings[0])))
-    return np.divide(embeddings, norm)
-
-
 class SupportedVSType:
     FAISS = 'faiss'
     MILVUS = 'milvus'
@@ -98,12 +84,6 @@ class KBService(ABC):
         status = delete_kb_from_db(self.kb_name)
         return status
 
-    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
-        '''
-        Convert a List[Document] into the arguments accepted by VectorStore.add_embeddings
-        '''
-        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
-
     def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
        Add a file to the knowledge base
        """
@@ -310,7 +290,7 @@ class KBServiceFactory:
             return PGKBService(kb_name, embed_model=embed_model)
         elif SupportedVSType.MILVUS == vector_store_type:
             from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
-            return MilvusKBService(kb_name,embed_model=embed_model)
+            return MilvusKBService(kb_name, embed_model=embed_model)
         elif SupportedVSType.ZILLIZ == vector_store_type:
             from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
             return ZillizKBService(kb_name, embed_model=embed_model)
@@ -416,33 +396,6 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
     return data
 
 
-class EmbeddingsFunAdapter(Embeddings):
-    def __init__(self, embed_model: str = EMBEDDING_MODEL):
-        self.embed_model = embed_model
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
-        return normalize(embeddings).tolist()
-
-    def embed_query(self, text: str) -> List[float]:
-        embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
-        query_embed = embeddings[0]
-        query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array to 2-D
-        normalized_query_embed = normalize(query_embed_2d)
-        return normalized_query_embed[0].tolist()  # flatten back to 1-D before returning
-
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
-        return normalize(embeddings).tolist()
-
-    async def aembed_query(self, text: str) -> List[float]:
-        embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
-        query_embed = embeddings[0]
-        query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array to 2-D
-        normalized_query_embed = normalize(query_embed_2d)
-        return normalized_query_embed[0].tolist()  # flatten back to 1-D before returning
-
-
 def score_threshold_process(score_threshold, k, docs):
     if score_threshold is not None:
         cmp = (
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index c05f6136..65cccbed 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -2,7 +2,7 @@ import os
 import shutil
 
 from configs import SCORE_THRESHOLD
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
 from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
@@ -14,7 +14,7 @@ class FaissKBService(KBService):
     vs_path: str
     kb_path: str
     vector_name: str = None
-    
+
     def vs_type(self) -> str:
         return SupportedVSType.FAISS
@@ -55,16 +55,16 @@ class FaissKBService(KBService):
         try:
             shutil.rmtree(self.kb_path)
         except Exception:
-            ...
+            pass
 
     def do_search(self,
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
                   ) -> List[Tuple[Document, float]]:
-        embed_func = EmbeddingsFunAdapter(self.embed_model)
-        embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
+            embeddings = vs.embeddings.embed_query(query)
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
         return docs
 
@@ -72,12 +72,14 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    **kwargs,
                    ) -> List[Dict]:
-        data = self._docs_to_embeddings(docs)  # embedding outside the store shortens the lock window
-
         with self.load_vector_store().acquire() as vs:
-            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
-                                    metadatas=data["metadatas"],
-                                    ids=kwargs.get("ids"))
+            texts = [x.page_content for x in docs]
+            metadatas = [x.metadata for x in docs]
+            embeddings = vs.embeddings.embed_documents(texts)
+
+            ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
+                                    metadatas=metadatas)
             if not kwargs.get("not_refresh_vs_cache"):
                 vs.save_local(self.vs_path)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 3c077008..7cb281e8 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -15,7 +15,7 @@ import langchain.document_loaders
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
-from server.utils import run_in_thread_pool, get_model_worker_config
+from server.utils import run_in_thread_pool
 import json
 from typing import List, Union, Dict, Tuple, Generator
 import chardet
@@ -187,8 +187,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
 def make_text_splitter(
         splitter_name,
         chunk_size,
-        chunk_overlap,
-        llm_model,
+        chunk_overlap
 ):
     """
     Build the requested text splitter from its parameters
     """
@@ -223,10 +222,6 @@ def make_text_splitter(
                 chunk_overlap=chunk_overlap
             )
         elif text_splitter_dict[splitter_name]["source"] == "huggingface":  ## load from huggingface
-            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
-                config = get_model_worker_config(llm_model)
-                text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
-                    config.get("model_path")
 
             if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
                 from transformers import GPT2TokenizerFast
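Reviewer note: with do_search and do_add_doc now reading the embedding function off the loaded store (vs.embeddings is the adapter the pool attached), a service-level call needs no separate adapter. A sketch, assuming a populated KB named "samples" and the usual KBService constructor:

    from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService

    kb = FaissKBService("samples")
    # Each (Document, score) pair comes from an inner-product search over
    # L2-normalized vectors from the KB's configured endpoint or local model.
    for doc, score in kb.do_search(query="What is Langchain-Chatchat?", top_k=3):
        print(round(score, 3), doc.metadata.get("source"))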
diff --git a/server/utils.py b/server/utils.py
index 06e12988..ab3a3dcc 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -88,7 +88,6 @@ def get_OpenAI(
         verbose: bool = True,
         **kwargs: Any,
 ) -> OpenAI:
-    # TODO: fetch model info from the API
     model = OpenAI(
         streaming=streaming,
@@ -319,7 +318,6 @@ def list_embed_models() -> List[str]:
     return list(MODEL_PATH["embed_model"])
 
 
-
 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:
         paths = MODEL_PATH[type]
@@ -347,17 +345,6 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
             return path_str  # THUDM/chatglm06b
 
 
-# fetch service info from server_config
-
-def get_model_worker_config(model_name: str = None) -> dict:
-    '''
-    Load the model worker's configuration.
-    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
-    '''
-
-    return {}
-
-
 def api_address() -> str:
     from configs.server_config import API_SERVER
 
@@ -559,15 +546,37 @@ def get_server_configs() -> Dict:
     return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
 
 
-def list_online_embed_models() -> List[str]:
+def list_online_embed_models(
+        endpoint_host: str,
+        endpoint_host_key: str,
+        endpoint_host_proxy: str
+) -> List[str]:
     ret = []
-    # TODO: fetch the list of supported models from the online API
+    client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT)
+    try:
+        headers = {
+            "Authorization": f"Bearer {endpoint_host_key}",
+        }
+        resp = client.get("/models", headers=headers)
+        if resp.status_code == 200:
+            models = resp.json().get("data", [])
+            for model in models:
+                if "embedding" in model.get("id", ""):
+                    ret.append(model.get("id"))
+
+    except Exception as e:
+        msg = f"Failed to fetch the online embeddings model list: {e}"
+        logger.error(f'{e.__class__.__name__}: {msg}',
+                     exc_info=e if log_verbose else None)
+    finally:
+        client.close()
     return ret
 
 
 def load_local_embeddings(model: str = None, device: str = embedding_device()):
     '''
-    Load embeddings from the cache to avoid racy loads across threads.
+    Load a local embeddings model from the cache to avoid racy loads across threads.
     '''
     from server.knowledge_base.kb_cache.base import embeddings_pool
     from configs import EMBEDDING_MODEL
diff --git a/tests/test_online_api.py b/tests/test_online_api.py
deleted file mode 100644
index e514d472..00000000
--- a/tests/test_online_api.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import sys
-from pathlib import Path
-root_path = Path(__file__).parent.parent
-sys.path.append(str(root_path))
-
-from configs import ONLINE_LLM_MODEL
-from server.model_workers.base import *
-from server.utils import get_model_worker_config, list_config_llm_models
-from pprint import pprint
-import pytest
-#
-#
-# workers = []
-# for x in list_config_llm_models()["online"]:
-#     if x in ONLINE_LLM_MODEL and x not in workers:
-#         workers.append(x)
-# print(f"all workers to test: {workers}")
-
-# workers = ["fangzhou-api"]
-
-
-@pytest.mark.parametrize("worker", workers)
-def test_chat(worker):
-    params = ApiChatParams(
-        messages = [
-            {"role": "user", "content": "你是谁"},
-        ],
-    )
-    print(f"\nchat with {worker} \n")
-
-    if worker_class := get_model_worker_config(worker).get("worker_class"):
-        for x in worker_class().do_chat(params):
-            pprint(x)
-            assert isinstance(x, dict)
-            assert x["error_code"] == 0
-
-
-@pytest.mark.parametrize("worker", workers)
-def test_embeddings(worker):
-    params = ApiEmbeddingsParams(
-        texts = [
-            "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。",
-            "一种利用 langchain 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。",
-        ]
-    )
-
-    if worker_class := get_model_worker_config(worker).get("worker_class"):
-        if worker_class.can_embedding():
-            print(f"\embeddings with {worker} \n")
-            resp = worker_class().do_embeddings(params)
-
-            pprint(resp, depth=2)
-            assert resp["code"] == 200
-            assert "data" in resp
-            embeddings = resp["data"]
-            assert isinstance(embeddings, list) and len(embeddings) > 0
-            assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
-            assert isinstance(embeddings[0][0], float)
-            print("vector length:", len(embeddings[0]))
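Reviewer note: the /models probe that replaced the old worker-based lookup can be exercised standalone; roughly, with httpx directly (endpoint values are placeholders), mirroring what list_online_embed_models now does:

    import httpx

    def probe_embedding_models(base_url: str, api_key: str) -> list:
        # GET {base_url}/models on an OpenAI-compatible server and keep the
        # ids that look like embedding models, as list_online_embed_models does.
        headers = {"Authorization": f"Bearer {api_key}"}
        with httpx.Client(base_url=base_url, timeout=30) as client:
            resp = client.get("/models", headers=headers)
            resp.raise_for_status()
            return [m["id"] for m in resp.json().get("data", [])
                    if "embedding" in m.get("id", "")]

    print(probe_embedding_models("https://api.example.com/v1", "sk-..."))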