From 5d422ca9a15bbf4453c4a04bb5eb41ae36a810dd Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 8 Feb 2024 11:30:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=B9=E5=BC=8F=EF=BC=8C=E6=89=80=E6=9C=89=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=BB=A5=20openai=20=E5=85=BC=E5=AE=B9=E6=A1=86?= =?UTF-8?q?=E6=9E=B6=E7=9A=84=E5=BD=A2=E5=BC=8F=E6=8E=A5=E5=85=A5=EF=BC=8C?= =?UTF-8?q?chatchat=20=E8=87=AA=E8=BA=AB=E4=B8=8D=E5=86=8D=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E6=A8=A1=E5=9E=8B=E3=80=82=20=E6=94=B9=E5=8F=98=20Emb?= =?UTF-8?q?eddings=20=E6=A8=A1=E5=9E=8B=E6=94=B9=E4=B8=BA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E6=A1=86=E6=9E=B6=20API=EF=BC=8C=E4=B8=8D=E5=86=8D?= =?UTF-8?q?=E6=89=8B=E5=8A=A8=E5=8A=A0=E8=BD=BD=EF=BC=8C=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89=20Embeddings=20Keyword=20=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=20=E4=BF=AE=E6=94=B9=E4=BE=9D=E8=B5=96=E6=96=87?= =?UTF-8?q?=E4=BB=B6=EF=BC=8C=E7=A7=BB=E9=99=A4=20torch=20transformers=20?= =?UTF-8?q?=E7=AD=89=E9=87=8D=E4=BE=9D=E8=B5=96=20=E6=9A=82=E6=97=B6?= =?UTF-8?q?=E7=A7=BB=E5=87=BA=E5=AF=B9=20loom=20=E7=9A=84=E9=9B=86?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 后续: 1、优化目录结构 2、检查合并中有无被覆盖的 0.2.10 内容 --- configs/basic_config.py.example | 27 +- configs/kb_config.py.example | 3 +- configs/model_config.py.example | 163 ++++---- configs/prompt_config.py.example | 26 +- configs/server_config.py.example | 5 +- embeddings/__init__.py | 0 embeddings/add_embedding_keywords.py | 79 ---- embeddings/embedding_keywords.txt | 3 - init_database.py | 8 +- requirements_api.txt | 57 --- requirements_lite.txt | 42 -- requirements_webui.txt | 9 - server/agent/tools_factory/text2image.py | 2 +- server/api.py | 30 +- server/chat/chat.py | 10 +- server/chat/completion.py | 17 +- server/chat/file_chat.py | 22 +- server/db/models/knowledge_base_model.py | 3 - .../repository/knowledge_base_repository.py | 23 +- server/embeddings/__init__.py | 0 server/embeddings/adapter.py | 130 ------- server/embeddings/core/__init__.py | 0 server/embeddings/core/embeddings_api.py | 94 ----- server/knowledge_base/kb_api.py | 9 +- server/knowledge_base/kb_cache/base.py | 50 +-- server/knowledge_base/kb_cache/faiss_cache.py | 40 +- server/knowledge_base/kb_doc_api.py | 47 +-- server/knowledge_base/kb_service/base.py | 28 +- .../kb_service/es_kb_service.py | 6 +- server/knowledge_base/kb_summary/base.py | 4 +- server/knowledge_base/kb_summary_api.py | 44 +-- server/knowledge_base/migrate.py | 4 +- server/knowledge_base/utils.py | 3 + server/localai_embeddings.py | 363 ++++++++++++++++++ server/reranker/reranker.py | 3 +- server/utils.py | 269 +++++++------ startup.py | 45 +-- webui.py | 31 +- webui_pages/dialogue/dialogue.py | 67 ++-- webui_pages/knowledge_base/knowledge_base.py | 79 +--- webui_pages/utils.py | 54 +-- 41 files changed, 757 insertions(+), 1142 deletions(-) delete mode 100644 embeddings/__init__.py delete mode 100644 embeddings/add_embedding_keywords.py delete mode 100644 embeddings/embedding_keywords.txt delete mode 100644 requirements_api.txt delete mode 100644 requirements_lite.txt delete mode 100644 requirements_webui.txt delete mode 100644 server/embeddings/__init__.py delete mode 100644 server/embeddings/adapter.py delete mode 100644 server/embeddings/core/__init__.py delete mode 100644 server/embeddings/core/embeddings_api.py create mode 100644 server/localai_embeddings.py diff --git a/configs/basic_config.py.example 
b/configs/basic_config.py.example index ccba9575..c167bc1e 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -1,8 +1,8 @@ import logging import os +from pathlib import Path + import langchain -import tempfile -import shutil # 是否显示详细日志 @@ -11,6 +11,16 @@ langchain.verbose = False # 通常情况下不需要更改以下内容 +# 用户数据根目录 +DATA_PATH = (Path(__file__).absolute().parent.parent) # / "data") +if not os.path.exists(DATA_PATH): + os.mkdir(DATA_PATH) + +# nltk 模型存储路径 +NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data") +import nltk +nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + # 日志格式 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() @@ -19,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT) # 日志存储路径 -LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") +LOG_PATH = os.path.join(DATA_PATH, "logs") if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH) # 模型生成内容(图片、视频、音频等)保存位置 -MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media") +MEDIA_PATH = os.path.join(DATA_PATH, "media") if not os.path.exists(MEDIA_PATH): os.mkdir(MEDIA_PATH) os.mkdir(os.path.join(MEDIA_PATH, "image")) @@ -32,9 +42,6 @@ if not os.path.exists(MEDIA_PATH): os.mkdir(os.path.join(MEDIA_PATH, "video")) # 临时文件目录,主要用于文件对话 -BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat") -if os.path.isdir(BASE_TEMP_DIR): - shutil.rmtree(BASE_TEMP_DIR) -os.makedirs(BASE_TEMP_DIR, exist_ok=True) - -MEDIA_PATH = None +BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") +if not os.path.exists(BASE_TEMP_DIR): + os.mkdir(BASE_TEMP_DIR) diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 4484165d..c4617da3 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -1,6 +1,6 @@ import os - +# 默认使用的知识库 DEFAULT_KNOWLEDGE_BASE = "samples" # 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es @@ -44,6 +44,7 @@ KB_INFO = { KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") if not os.path.exists(KB_ROOT_PATH): os.mkdir(KB_ROOT_PATH) + # 数据库默认存储路径。 # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") diff --git a/configs/model_config.py.example b/configs/model_config.py.example index a09def32..19e7db7a 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,10 +1,25 @@ import os -MODEL_ROOT_PATH = "" -EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh -EMBEDDING_DEVICE = "auto" -EMBEDDING_KEYWORD_FILE = "keywords.txt" -EMBEDDING_MODEL_OUTPUT_PATH = "output" + +# 默认选用的 LLM 名称 +DEFAULT_LLM_MODEL = "chatglm3-6b" + +# 默认选用的 Embedding 名称 +DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5" + + +# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) +Agent_MODEL = None + +# 历史对话轮数 +HISTORY_LEN = 3 + +# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 +MAX_TOKENS = None + +# LLM通用对话参数 +TEMPERATURE = 0.7 +# TOP_P = 0.95 # ChatOpenAI暂不支持该参数 SUPPORT_AGENT_MODELS = [ "chatglm3-6b", @@ -12,97 +27,103 @@ SUPPORT_AGENT_MODELS = [ "Qwen-14B-Chat", "Qwen-7B-Chat", ] + + LLM_MODEL_CONFIG = { +# 意图识别不需要输出,模型后台知道就行 "preprocess_model": { - # "Mixtral-8x7B-v0.1": { - # "temperature": 0.01, - # "max_tokens": 5, - # "prompt_name": "default", - # "callbacks": False - # }, - "chatglm3-6b": { + DEFAULT_LLM_MODEL: { "temperature": 0.05, "max_tokens": 4096, + "history_len": 100, "prompt_name": "default", "callbacks": False }, }, 
"llm_model": { - # "Mixtral-8x7B-v0.1": { - # "temperature": 0.9, - # "max_tokens": 4000, - # "history_len": 5, - # "prompt_name": "default", - # "callbacks": True - # }, - "chatglm3-6b": { - "temperature": 0.05, + DEFAULT_LLM_MODEL: { + "temperature": 0.9, "max_tokens": 4096, - "prompt_name": "default", "history_len": 10, + "prompt_name": "default", "callbacks": True }, }, "action_model": { - # "Qwen-14B-Chat": { - # "temperature": 0.05, - # "max_tokens": 4096, - # "prompt_name": "qwen", - # "callbacks": True - # }, - "chatglm3-6b": { - "temperature": 0.05, + DEFAULT_LLM_MODEL: { + "temperature": 0.01, "max_tokens": 4096, "prompt_name": "ChatGLM3", "callbacks": True }, - # "zhipu-api": { - # "temperature": 0.01, - # "max_tokens": 4096, - # "prompt_name": "ChatGLM3", - # "callbacks": True - # } - - }, + }, "postprocess_model": { - "zhipu-api": { + DEFAULT_LLM_MODEL: { "temperature": 0.01, "max_tokens": 4096, "prompt_name": "default", "callbacks": True } }, -} - - -MODEL_PATH = { - "embed_model": { - "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", - "ernie-base": "nghuyong/ernie-3.0-base-zh", - "text2vec-base": "shibing624/text2vec-base-chinese", - "text2vec": "GanymedeNil/text2vec-large-chinese", - "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase", - "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence", - "text2vec-multilingual": "shibing624/text2vec-base-multilingual", - "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese", - "m3e-small": "moka-ai/m3e-small", - "m3e-base": "moka-ai/m3e-base", - "m3e-large": "moka-ai/m3e-large", - "bge-small-zh": "BAAI/bge-small-zh", - "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh", - "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", - "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5", - "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", - "piccolo-base-zh": "sensenova/piccolo-base-zh", - "piccolo-large-zh": "sensenova/piccolo-large-zh", - "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large", - "text-embedding-ada-002": "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", - } + "image_model": { + "sd-turbo": { + "size": "256*256", + } + }, + "multimodal_model": { + "qwen-vl": {} + }, } -NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") - -LOOM_CONFIG = "./loom.yaml" -OPENAI_KEY = None -OPENAI_PROXY = None +# 可以通过 loom/xinference/oneapi/fatchat 启动模型服务,然后将其 URL 和 KEY 配置过来即可。 +MODEL_PLATFORMS = [ + { + "platform_name": "openai-api", + "platform_type": "openai", + "llm_models": [ + "gpt-3.5-turbo", + ], + "api_base_url": "https://api.openai.com/v1", + "api_key": "sk-", + "api_proxy": "", + }, + + { + "platform_name": "xinference", + "platform_type": "xinference", + "llm_models": [ + "chatglm3-6b", + ], + "embed_models": [ + "bge-large-zh-v1.5", + ], + "image_models": [ + "sd-turbo", + ], + "multimodal_models": [ + "qwen-vl", + ], + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + }, + + { + "platform_name": "oneapi", + "platform_type": "oneapi", + "api_key": "", + "llm_models": [ + "chatglm3-6b", + ], + }, + + { + "platform_name": "loom", + "platform_type": "loom", + "api_key": "", + "llm_models": [ + "chatglm3-6b", + ], + }, +] + +LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml") diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index 8101b20b..2b2967c4 100644 --- a/configs/prompt_config.py.example +++ 
b/configs/prompt_config.py.example @@ -98,7 +98,25 @@ PROMPT_TEMPLATES = { 'Begin!\n\n' 'Question: {input}\n\n' '{agent_scratchpad}\n\n', - + "structured-chat-agent": + 'Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n' + '{tools}\n\n' + 'Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\n' + 'Valid "action" values: "Final Answer" or {tool_names}\n\n' + 'Provide only ONE action per $JSON_BLOB, as shown:\n\n' + '```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\n' + 'Follow this format:\n\n' + 'Question: input question to answer\n' + 'Thought: consider previous and subsequent steps\n' + 'Action:\n```\n$JSON_BLOB\n```\n' + 'Observation: action result\n' + '... (repeat Thought/Action/Observation N times)\n' + 'Thought: I know what to respond\n' + 'Action:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\n' + 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation\n' + '{input}\n\n' + '{agent_scratchpad}\n\n' + # '(reminder to respond in a JSON blob no matter what)' }, "postprocess_model": { "default": "{{input}}", @@ -130,7 +148,7 @@ TOOL_CONFIG = { "bing": { "result_len": 3, "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", - "bing_key": "680a39347d7242c5bd2d7a9576a125b7", + "bing_key": "", }, "metaphor": { "result_len": 3, @@ -184,4 +202,8 @@ TOOL_CONFIG = { "device": "cuda:2" }, + "text2images": { + "use": False, + }, + } diff --git a/configs/server_config.py.example b/configs/server_config.py.example index cf9cb44f..90182036 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -1,5 +1,5 @@ import sys -from configs.model_config import LLM_DEVICE + # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 HTTPX_DEFAULT_TIMEOUT = 300.0 @@ -11,10 +11,11 @@ OPEN_CROSS_DOMAIN = True # 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1" + # webui.py server WEBUI_SERVER = { "host": DEFAULT_BIND_HOST, - "port": 7870, + "port": 8501, } # api.py server diff --git a/embeddings/__init__.py b/embeddings/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/embeddings/add_embedding_keywords.py b/embeddings/add_embedding_keywords.py deleted file mode 100644 index f46dee29..00000000 --- a/embeddings/add_embedding_keywords.py +++ /dev/null @@ -1,79 +0,0 @@ -''' -该功能是为了将关键词加入到embedding模型中,以便于在embedding模型中进行关键词的embedding -该功能的实现是通过修改embedding模型的tokenizer来实现的 -该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效,输出后的模型保存在原本模型 -感谢@CharlesJu1和@charlesyju的贡献提出了想法和最基础的PR - -保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳 -''' -import sys - -sys.path.append("..") -import os -import torch - -from datetime import datetime -from configs import ( - MODEL_PATH, - EMBEDDING_MODEL, - EMBEDDING_KEYWORD_FILE, -) - -from safetensors.torch import save_model -from sentence_transformers import SentenceTransformer -from langchain_core._api import deprecated - - -@deprecated( - since="0.3.0", - message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃", - removal="0.3.0" - ) -def get_keyword_embedding(bert_model, tokenizer, key_words): - tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True) - input_ids = tokenizer_output['input_ids'] - input_ids = input_ids[:, 1:-1] - - 
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids) - keyword_embedding = torch.mean(keyword_embedding, 1) - return keyword_embedding - - -def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None): - key_words = [] - with open(keyword_file, "r") as f: - for line in f: - key_words.append(line.strip()) - - st_model = SentenceTransformer(model_name) - key_words_len = len(key_words) - word_embedding_model = st_model._first_module() - bert_model = word_embedding_model.auto_model - tokenizer = word_embedding_model.tokenizer - key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words) - - embedding_weight = bert_model.embeddings.word_embeddings.weight - embedding_weight_len = len(embedding_weight) - tokenizer.add_tokens(key_words) - bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) - embedding_weight = bert_model.embeddings.word_embeddings.weight - with torch.no_grad(): - embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding - - if output_model_path: - os.makedirs(output_model_path, exist_ok=True) - word_embedding_model.save(output_model_path) - safetensors_file = os.path.join(output_model_path, "model.safetensors") - metadata = {'format': 'pt'} - save_model(bert_model, safetensors_file, metadata) - print("save model to {}".format(output_model_path)) - - -def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE): - keyword_file = os.path.join(path) - model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL] - model_parent_directory = os.path.dirname(model_name) - current_time = datetime.now().strftime('%Y%m%d_%H%M%S') - output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time) - output_model_path = os.path.join(model_parent_directory, output_model_name) - add_keyword_to_model(model_name, keyword_file, output_model_path) diff --git a/embeddings/embedding_keywords.txt b/embeddings/embedding_keywords.txt deleted file mode 100644 index 3822b992..00000000 --- a/embeddings/embedding_keywords.txt +++ /dev/null @@ -1,3 +0,0 @@ -Langchain-Chatchat -数据科学与大数据技术 -人工智能与先进计算 \ No newline at end of file diff --git a/init_database.py b/init_database.py index a394b758..1ca0fa60 100644 --- a/init_database.py +++ b/init_database.py @@ -2,9 +2,7 @@ import sys sys.path.append(".") from server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL -import nltk -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path +from configs.model_config import NLTK_DATA_PATH, DEFAULT_EMBEDDING_MODEL from datetime import datetime @@ -19,7 +17,7 @@ if __name__ == "__main__": action="store_true", help=(''' recreate vector store. - use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed. + use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/DEFAULT_EMBEDDING_MODEL changed. 
''' ) ) @@ -87,7 +85,7 @@ if __name__ == "__main__": "-e", "--embed-model", type=str, - default=EMBEDDING_MODEL, + default=DEFAULT_EMBEDDING_MODEL, help=("specify embeddings model.") ) diff --git a/requirements_api.txt b/requirements_api.txt deleted file mode 100644 index 801cea1d..00000000 --- a/requirements_api.txt +++ /dev/null @@ -1,57 +0,0 @@ -# API requirements - -langchain>=0.0.350 -langchain-experimental>=0.0.42 -pydantic==1.10.13 -fschat==0.2.35 -openai~=1.9.0 -fastapi~=0.109.0 -sse_starlette==1.8.2 -nltk>=3.8.1 -uvicorn>=0.27.0.post1 -starlette~=0.35.0 -unstructured[all-docs]==0.11.0 -python-magic-bin; sys_platform == 'win32' -SQLAlchemy==2.0.19 -faiss-cpu~=1.7.4 -accelerate~=0.24.1 -spacy~=3.7.2 -PyMuPDF~=1.23.8 -rapidocr_onnxruntime==1.3.8 -requests~=2.31.0 -pathlib~=1.0.1 -pytest~=7.4.3 -numexpr~=2.8.6 -strsimpy~=0.2.1 -markdownify~=0.11.6 -tiktoken~=0.5.2 -tqdm>=4.66.1 -websockets>=12.0 -numpy~=1.24.4 -pandas~=2.0.3 -einops>=0.7.0 -transformers_stream_generator==0.0.4 -vllm==0.2.7; sys_platform == "linux" -httpx==0.26.0 -httpx_sse==0.4.0 -llama-index==0.9.35 -pyjwt==2.8.0 - -# jq==1.6.0 -# beautifulsoup4~=4.12.2 -# pysrt~=1.1.2 -# dashscope==1.13.6 -# arxiv~=2.1.0 -# youtube-search~=2.1.2 -# duckduckgo-search~=3.9.9 -# metaphor-python~=0.1.23 - -# volcengine>=1.0.119 -# pymilvus==2.3.6 -# psycopg2==2.9.9 -# pgvector>=0.2.4 -# chromadb==0.4.13 - -#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat -#autoawq==0.1.8 # For Int4 -#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files \ No newline at end of file diff --git a/requirements_lite.txt b/requirements_lite.txt deleted file mode 100644 index dc420c62..00000000 --- a/requirements_lite.txt +++ /dev/null @@ -1,42 +0,0 @@ -langchain==0.0.354 -langchain-experimental==0.0.47 -pydantic==1.10.13 -fschat~=0.2.35 -openai~=1.9.0 -fastapi~=0.109.0 -sse_starlette~=1.8.2 -nltk~=3.8.1 -uvicorn>=0.27.0.post1 -starlette~=0.35.0 -unstructured[all-docs]~=0.12.0 -python-magic-bin; sys_platform == 'win32' -SQLAlchemy~=2.0.25 -faiss-cpu~=1.7.4 -accelerate~=0.24.1 -spacy~=3.7.2 -PyMuPDF~=1.23.16 -rapidocr_onnxruntime~=1.3.8 -requests~=2.31.0 -pathlib~=1.0.1 -pytest~=7.4.3 -llama-index==0.9.35 -pyjwt==2.8.0 -httpx==0.26.0 -httpx_sse==0.4.0 - -dashscope==1.13.6 -arxiv~=2.1.0 -youtube-search~=2.1.2 -duckduckgo-search~=3.9.9 -metaphor-python~=0.1.23 -watchdog~=3.0.0 -# volcengine>=1.0.119 -# pymilvus>=2.3.4 -# psycopg2==2.9.9 -# pgvector>=0.2.4 -# chromadb==0.4.13 - -# jq==1.6.0 -# beautifulsoup4~=4.12.2 -# pysrt~=1.1.2 - diff --git a/requirements_webui.txt b/requirements_webui.txt deleted file mode 100644 index a993c513..00000000 --- a/requirements_webui.txt +++ /dev/null @@ -1,9 +0,0 @@ -streamlit==1.30.0 -streamlit-option-menu==0.3.12 -streamlit-antd-components==0.3.1 -streamlit-chatbox==1.1.11 -streamlit-modal==0.1.0 -streamlit-aggrid==0.3.4.post3 -httpx==0.26.0 -httpx_sse==0.4.0 -watchdog=s=3.0.0 \ No newline at end of file diff --git a/server/agent/tools_factory/text2image.py b/server/agent/tools_factory/text2image.py index 99d9e2a5..c6983666 100644 --- a/server/agent/tools_factory/text2image.py +++ b/server/agent/tools_factory/text2image.py @@ -6,7 +6,7 @@ from typing import List import uuid from langchain.agents import tool -from pydantic.v1 import BaseModel, Field +from langchain.pydantic_v1 import Field import openai from pydantic.fields import FieldInfo diff --git a/server/api.py b/server/api.py index 92e595a6..7c02b17a 100644 --- a/server/api.py +++ b/server/api.py @@ -1,13 +1,9 @@ -import nltk import sys 
import os -from server.knowledge_base.kb_doc_api import update_kb_endpoint - sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs import VERSION, MEDIA_PATH -from configs.model_config import NLTK_DATA_PATH from configs.server_config import OPEN_CROSS_DOMAIN import argparse import uvicorn @@ -18,14 +14,11 @@ from starlette.responses import RedirectResponse from server.chat.chat import chat from server.chat.completion import completion from server.chat.feedback import chat_feedback -from server.embeddings.core.embeddings_api import embed_texts_endpoint - +from server.knowledge_base.model.kb_document_model import DocumentWithVSId from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs, get_prompt_template) from typing import List, Literal -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path - async def document(): return RedirectResponse(url="/docs") @@ -95,11 +88,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): summary="要求llm模型补全(通过LLMChain)", )(completion) - app.post("/other/embed_texts", - tags=["Other"], - summary="将文本向量化,支持本地模型和在线模型", - )(embed_texts_endpoint) - # 媒体文件 app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") @@ -109,8 +97,7 @@ def mount_knowledge_routes(app: FastAPI): from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, - search_docs, DocumentWithVSId, update_info, - update_docs_by_id,) + search_docs, update_info) app.post("/chat/file_chat", tags=["Knowledge Base Management"], @@ -146,13 +133,6 @@ def mount_knowledge_routes(app: FastAPI): summary="搜索知识库" )(search_docs) - app.post("/knowledge_base/update_docs_by_id", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="直接更新知识库文档" - )(update_docs_by_id) - - app.post("/knowledge_base/upload_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, @@ -171,12 +151,6 @@ def mount_knowledge_routes(app: FastAPI): summary="更新知识库介绍" )(update_info) - app.post("/knowledge_base/update_kb_endpoint", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="更新知识库在线api接入点配置" - )(update_kb_endpoint) - app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/chat/chat.py b/server/chat/chat.py index d80a1c98..66dfd582 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import AsyncIterable, List, Union, Dict, Annotated +from typing import AsyncIterable, List from fastapi import Body from fastapi.responses import StreamingResponse @@ -21,7 +21,7 @@ from server.db.repository import add_message_to_db from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus -def create_models_from_config(configs, openai_config, callbacks, stream): +def create_models_from_config(configs, callbacks, stream): if configs is None: configs = {} models = {} @@ -30,9 +30,6 @@ def create_models_from_config(configs, openai_config, callbacks, stream): for model_name, params in model_configs.items(): callbacks = callbacks if params.get('callbacks', False) else None model_instance = get_ChatOpenAI( - endpoint_host=openai_config.get('endpoint_host', None), - endpoint_host_key=openai_config.get('endpoint_host_key', None), - endpoint_host_proxy=openai_config.get('endpoint_host_proxy', 
None), model_name=model_name, temperature=params.get('temperature', 0.5), max_tokens=params.get('max_tokens', 1000), @@ -116,7 +113,6 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 ), stream: bool = Body(True, description="流式输出"), chat_model_config: dict = Body({}, description="LLM 模型配置", examples=[]), - openai_config: dict = Body({}, description="openaiEndpoint配置", examples=[]), tool_config: dict = Body({}, description="工具配置", examples=[]), ): async def chat_iterator() -> AsyncIterable[str]: @@ -129,7 +125,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callback = AgentExecutorAsyncIteratorCallbackHandler() callbacks = [callback] models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config, - openai_config=openai_config, stream=stream) + stream=stream) tools = [tool for tool in all_tools if tool.name in tool_config] tools = [t.copy(update={"callbacks": callbacks}) for t in tools] full_chain = create_models_chains(prompts=prompts, diff --git a/server/chat/completion.py b/server/chat/completion.py index 559f10bb..05b45740 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -1,6 +1,5 @@ from fastapi import Body -from fastapi.responses import StreamingResponse -from configs import LLM_MODEL_CONFIG +from sse_starlette.sse import EventSourceResponse from server.utils import wrap_done, get_OpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler @@ -14,9 +13,6 @@ from server.utils import get_prompt_template async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), stream: bool = Body(False, description="流式输出"), echo: bool = Body(False, description="除了输出之外,还回显输入"), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -27,9 +23,6 @@ async def completion(query: str = Body(..., description="用户输入", examples #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 async def completion_iterator(query: str, - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, model_name: str = None, prompt_name: str = prompt_name, echo: bool = echo, @@ -40,9 +33,6 @@ async def completion(query: str = Body(..., description="用户输入", examples max_tokens = None model = get_OpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, @@ -72,10 +62,7 @@ async def completion(query: str = Body(..., description="用户输入", examples await task - return StreamingResponse(completion_iterator(query=query, - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, + return EventSourceResponse(completion_iterator(query=query, model_name=model_name, prompt_name=prompt_name), ) diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index 5fac6f0a..f9b31714 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -1,8 +1,7 @@ from fastapi import Body, File, Form, UploadFile from fastapi.responses import StreamingResponse from configs import (VECTOR_SEARCH_TOP_K, 
SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) -from server.embeddings.adapter import load_temp_adapter_embeddings -from server.utils import (wrap_done, get_ChatOpenAI, +from server.utils import (wrap_done, get_ChatOpenAI, get_Embeddings, BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool) from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool from langchain.chains import LLMChain @@ -57,9 +56,6 @@ def _parse_files_in_thread( def upload_temp_docs( - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), files: List[UploadFile] = File(..., description="上传文件,支持多文件"), prev_id: str = Form(None, description="前知识库ID"), chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), @@ -86,11 +82,7 @@ def upload_temp_docs( else: failed_files.append({file: msg}) - with memo_faiss_pool.load_vector_store(kb_name=id, - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - ).acquire() as vs: + with memo_faiss_pool.load_vector_store(kb_name=id).acquire() as vs: vs.add_documents(documents) return BaseResponse(data={"id": id, "failed_files": failed_files}) @@ -110,9 +102,6 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -131,17 +120,12 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= max_tokens = None model = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, callbacks=[callback], ) - embed_func = load_temp_adapter_embeddings(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy) + embed_func = get_Embeddings() embeddings = await embed_func.aembed_query(query) with memo_faiss_pool.acquire(knowledge_id) as vs: docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py index e5cd0809..f9035af4 100644 --- a/server/db/models/knowledge_base_model.py +++ b/server/db/models/knowledge_base_model.py @@ -12,9 +12,6 @@ class KnowledgeBaseModel(Base): kb_name = Column(String(50), comment='知识库名称') kb_info = Column(String(200), comment='知识库简介(用于Agent)') vs_type = Column(String(50), comment='向量库类型') - endpoint_host = Column(String(50), comment='接入点地址') - endpoint_host_key = Column(String(50), comment='接入点key') - endpoint_host_proxy = Column(String(50), comment='接入点代理地址') embed_model = Column(String(50), comment='嵌入模型名称') file_count = Column(Integer, default=0, comment='文件数量') create_time = Column(DateTime, default=func.now(), comment='创建时间') diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index d241328e..b39c8c57 100644 --- a/server/db/repository/knowledge_base_repository.py 
+++ b/server/db/repository/knowledge_base_repository.py @@ -3,22 +3,16 @@ from server.db.session import with_session @with_session -def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model, endpoint_host: str = None, - endpoint_host_key: str = None, endpoint_host_proxy: str = None): +def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): # 创建知识库实例 kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() if not kb: - kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model, - endpoint_host=endpoint_host, endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy) + kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model) session.add(kb) else: # update kb with new vs_type and embed_model kb.kb_info = kb_info kb.vs_type = vs_type kb.embed_model = embed_model - kb.endpoint_host = endpoint_host - kb.endpoint_host_key = endpoint_host_key - kb.endpoint_host_proxy = endpoint_host_proxy return True @@ -54,16 +48,6 @@ def delete_kb_from_db(session, kb_name): return True -@with_session -def update_kb_endpoint_from_db(session, kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy): - kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first() - if kb: - kb.endpoint_host = endpoint_host - kb.endpoint_host_key = endpoint_host_key - kb.endpoint_host_proxy = endpoint_host_proxy - return True - - @with_session def get_kb_detail(session, kb_name: str) -> dict: kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first() @@ -72,9 +56,6 @@ def get_kb_detail(session, kb_name: str) -> dict: "kb_name": kb.kb_name, "kb_info": kb.kb_info, "vs_type": kb.vs_type, - "endpoint_host": kb.endpoint_host, - "endpoint_host_key": kb.endpoint_host_key, - "endpoint_host_proxy": kb.endpoint_host_proxy, "embed_model": kb.embed_model, "file_count": kb.file_count, "create_time": kb.create_time, diff --git a/server/embeddings/__init__.py b/server/embeddings/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/embeddings/adapter.py b/server/embeddings/adapter.py deleted file mode 100644 index 50ccdf04..00000000 --- a/server/embeddings/adapter.py +++ /dev/null @@ -1,130 +0,0 @@ -import numpy as np -from typing import List, Union, Dict -from langchain.embeddings.base import Embeddings -from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - EMBEDDING_MODEL, KB_INFO) -from server.embeddings.core.embeddings_api import embed_texts, aembed_texts -from server.utils import embedding_device - - -class EmbeddingsFunAdapter(Embeddings): - _endpoint_host: str - _endpoint_host_key: str - _endpoint_host_proxy: str - - def __init__(self, - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_model: str = EMBEDDING_MODEL, - ): - self._endpoint_host = endpoint_host - self._endpoint_host_key = endpoint_host_key - self._endpoint_host_proxy = endpoint_host_proxy - self.embed_model = embed_model - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - embeddings = embed_texts(texts=texts, - endpoint_host=self._endpoint_host, - endpoint_host_key=self._endpoint_host_key, - endpoint_host_proxy=self._endpoint_host_proxy, - embed_model=self.embed_model, - to_query=False).data - return self._normalize(embeddings=embeddings).tolist() - - def embed_query(self, text: str) -> List[float]: - embeddings = embed_texts(texts=[text], - 
endpoint_host=self._endpoint_host, - endpoint_host_key=self._endpoint_host_key, - endpoint_host_proxy=self._endpoint_host_proxy, - embed_model=self.embed_model, - to_query=True).data - query_embed = embeddings[0] - query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 - normalized_query_embed = self._normalize(embeddings=query_embed_2d) - return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 - - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - embeddings = (await aembed_texts(texts=texts, - endpoint_host=self._endpoint_host, - endpoint_host_key=self._endpoint_host_key, - endpoint_host_proxy=self._endpoint_host_proxy, - embed_model=self.embed_model, - to_query=False)).data - return self._normalize(embeddings=embeddings).tolist() - - async def aembed_query(self, text: str) -> List[float]: - embeddings = (await aembed_texts(texts=[text], - endpoint_host=self._endpoint_host, - endpoint_host_key=self._endpoint_host_key, - endpoint_host_proxy=self._endpoint_host_proxy, - embed_model=self.embed_model, - to_query=True)).data - query_embed = embeddings[0] - query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 - normalized_query_embed = self._normalize(embeddings=query_embed_2d) - return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 - - @staticmethod - def _normalize(embeddings: List[List[float]]) -> np.ndarray: - ''' - sklearn.preprocessing.normalize 的替代(使用 L2),避免安装 scipy, scikit-learn - #TODO 此处内容处理错误 - ''' - norm = np.linalg.norm(embeddings, axis=1) - norm = np.reshape(norm, (norm.shape[0], 1)) - norm = np.tile(norm, (1, len(embeddings[0]))) - return np.divide(embeddings, norm) - - -def load_kb_adapter_embeddings( - kb_name: str, - embed_device: str = embedding_device(), - default_embed_model: str = EMBEDDING_MODEL, -) -> "EmbeddingsFunAdapter": - """ - 加载知识库配置的Embeddings模型 - 本地模型最终会通过load_embeddings加载 - 在线模型会在适配器中直接返回 - :param kb_name: - :param embed_device: - :param default_embed_model: - :return: - """ - from server.db.repository.knowledge_base_repository import get_kb_detail - - kb_detail = get_kb_detail(kb_name) - embed_model = kb_detail.get("embed_model", default_embed_model) - endpoint_host = kb_detail.get("endpoint_host", None) - endpoint_host_key = kb_detail.get("endpoint_host_key", None) - endpoint_host_proxy = kb_detail.get("endpoint_host_proxy", None) - - return EmbeddingsFunAdapter(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_model=embed_model) - - -def load_temp_adapter_embeddings( - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_device: str = embedding_device(), - default_embed_model: str = EMBEDDING_MODEL, -) -> "EmbeddingsFunAdapter": - """ - 加载临时的Embeddings模型 - 本地模型最终会通过load_embeddings加载 - 在线模型会在适配器中直接返回 - :param endpoint_host: - :param endpoint_host_key: - :param endpoint_host_proxy: - :param embed_device: - :param default_embed_model: - :return: - """ - - return EmbeddingsFunAdapter(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_model=default_embed_model) diff --git a/server/embeddings/core/__init__.py b/server/embeddings/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/embeddings/core/embeddings_api.py b/server/embeddings/core/embeddings_api.py deleted file mode 100644 index 5ac5d6b7..00000000 --- a/server/embeddings/core/embeddings_api.py +++ /dev/null @@ -1,94 +0,0 @@ -from langchain.docstore.document 
import Document -from configs import EMBEDDING_MODEL, logger, CHUNK_SIZE -from server.utils import BaseResponse, list_embed_models, list_online_embed_models -from fastapi import Body -from fastapi.concurrency import run_in_threadpool -from typing import Dict, List - - -def embed_texts( - texts: List[str], - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_model: str = EMBEDDING_MODEL, - to_query: bool = False, -) -> BaseResponse: - ''' - 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) - TODO: 也许需要加入缓存机制,减少 token 消耗 - ''' - try: - if embed_model in list_embed_models(): # 使用本地Embeddings模型 - from server.utils import load_local_embeddings - embeddings = load_local_embeddings(model=embed_model) - return BaseResponse(data=embeddings.embed_documents(texts)) - - # 使用在线API - if embed_model in list_online_embed_models(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy): - from langchain.embeddings.openai import OpenAIEmbeddings - embeddings = OpenAIEmbeddings(model=embed_model, - openai_api_key=endpoint_host_key if endpoint_host_key else "None", - openai_api_base=endpoint_host if endpoint_host else "None", - openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, - chunk_size=CHUNK_SIZE) - return BaseResponse(data=embeddings.embed_documents(texts)) - - return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。") - except Exception as e: - logger.error(e) - return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") - - -async def aembed_texts( - texts: List[str], - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_model: str = EMBEDDING_MODEL, - to_query: bool = False, -) -> BaseResponse: - ''' - 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) - ''' - try: - if embed_model in list_embed_models(): # 使用本地Embeddings模型 - from server.utils import load_local_embeddings - - embeddings = load_local_embeddings(model=embed_model) - return BaseResponse(data=await embeddings.aembed_documents(texts)) - - # 使用在线API - if embed_model in list_online_embed_models(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy): - return await run_in_threadpool(embed_texts, - texts=texts, - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_model=embed_model, - to_query=to_query) - except Exception as e: - logger.error(e) - return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") - - -def embed_texts_endpoint( - texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), - embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型"), - to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"), -) -> BaseResponse: - ''' - 接入api,对文本进行向量化,返回 BaseResponse(data=List[List[float]]) - ''' - return embed_texts(texts=texts, - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_model=embed_model, to_query=to_query) diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index cab28d00..a4850556 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse from 
server.knowledge_base.utils import validate_kb_name from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_base_repository import list_kbs_from_db -from configs import EMBEDDING_MODEL, logger, log_verbose +from configs import DEFAULT_EMBEDDING_MODEL, logger, log_verbose from fastapi import Body @@ -14,10 +14,7 @@ def list_kbs(): def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), - embed_model: str = Body(EMBEDDING_MODEL), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), ) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): @@ -31,7 +28,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) try: - kb.create_kb(endpoint_host, endpoint_host_key, endpoint_host_proxy) + kb.create_kb() except Exception as e: msg = f"创建知识库出错: {e}" logger.error(f'{e.__class__.__name__}: {msg}', diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py index 96559b86..ab61099a 100644 --- a/server/knowledge_base/kb_cache/base.py +++ b/server/knowledge_base/kb_cache/base.py @@ -1,9 +1,8 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores.faiss import FAISS import threading -from configs import (EMBEDDING_MODEL, CHUNK_SIZE, +from configs import (DEFAULT_EMBEDDING_MODEL, CHUNK_SIZE, logger, log_verbose) -from server.utils import embedding_device, get_model_path from contextlib import contextmanager from collections import OrderedDict from typing import List, Any, Union, Tuple @@ -98,50 +97,3 @@ class CachePool: else: return cache - -class EmbeddingsPool(CachePool): - - def load_embeddings(self, model: str = None, device: str = None) -> Embeddings: - """ - 本地Embeddings模型加载 - :param model: - :param device: - :return: - """ - self.atomic.acquire() - model = model or EMBEDDING_MODEL - device = embedding_device() - key = (model, device) - if not self.get(key): - item = ThreadSafeObject(key, pool=self) - self.set(key, item) - with item.acquire(msg="初始化"): - self.atomic.release() - if 'bge-' in model: - from langchain.embeddings import HuggingFaceBgeEmbeddings - if 'zh' in model: - # for chinese model - query_instruction = "为这个句子生成表示以用于检索相关文章:" - elif 'en' in model: - # for english model - query_instruction = "Represent this sentence for searching relevant passages:" - else: - # maybe ReRanker or else, just use empty string instead - query_instruction = "" - embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model), - model_kwargs={'device': device}, - query_instruction=query_instruction) - if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding - embeddings.query_instruction = "" - else: - from langchain.embeddings.huggingface import HuggingFaceEmbeddings - embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), - model_kwargs={'device': device}) - item.obj = embeddings - item.finish_loading() - else: - self.atomic.release() - return self.get(key).obj - - -embeddings_pool = EmbeddingsPool(cache_num=1) diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 86d3348e..067ba37a 100644 --- 
a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -1,7 +1,6 @@ from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM -from server.embeddings.adapter import load_kb_adapter_embeddings, load_temp_adapter_embeddings from server.knowledge_base.kb_cache.base import * -# from server.utils import load_local_embeddings +from server.utils import get_Embeddings from server.knowledge_base.utils import get_vs_path from langchain.vectorstores.faiss import FAISS from langchain.docstore.in_memory import InMemoryDocstore @@ -53,13 +52,11 @@ class _FaissPool(CachePool): def new_vector_store( self, kb_name: str, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> FAISS: # create an empty vector store - embeddings = load_kb_adapter_embeddings(kb_name=kb_name, - embed_device=embed_device, default_embed_model=embed_model) + embeddings = get_Embeddings(embed_model=embed_model) doc = Document(page_content="init", metadata={}) vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") ids = list(vector_store.docstore._dict.keys()) @@ -68,18 +65,11 @@ class _FaissPool(CachePool): def new_temp_vector_store( self, - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> FAISS: # create an empty vector store - embeddings = load_temp_adapter_embeddings(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_device=embed_device, default_embed_model=embed_model) + embeddings = get_Embeddings(embed_model=embed_model) doc = Document(page_content="init", metadata={}) vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) ids = list(vector_store.docstore._dict.keys()) @@ -102,8 +92,7 @@ class KBFaissPool(_FaissPool): kb_name: str, vector_name: str = None, create: bool = True, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() vector_name = vector_name or embed_model @@ -118,15 +107,13 @@ class KBFaissPool(_FaissPool): vs_path = get_vs_path(kb_name, vector_name) if os.path.isfile(os.path.join(vs_path, "index.faiss")): - embeddings = load_kb_adapter_embeddings(kb_name=kb_name, - embed_device=embed_device, default_embed_model=embed_model) + embeddings = get_Embeddings(embed_model=embed_model) vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) elif create: # create an empty vector store if not os.path.exists(vs_path): os.makedirs(vs_path) - vector_store = self.new_vector_store(kb_name=kb_name, - embed_model=embed_model, embed_device=embed_device) + vector_store = self.new_vector_store(kb_name=kb_name, embed_model=embed_model) vector_store.save_local(vs_path) else: raise RuntimeError(f"knowledge base {kb_name} not exist.") @@ -148,11 +135,7 @@ class MemoFaissPool(_FaissPool): def load_vector_store( self, kb_name: str, - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() cache = self.get(kb_name) @@ -163,10 +146,7 @@ class MemoFaissPool(_FaissPool): self.atomic.release() 
logger.info(f"loading vector store in '{kb_name}' to memory.") # create an empty vector store - vector_store = self.new_temp_vector_store(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, - embed_model=embed_model, embed_device=embed_device) + vector_store = self.new_temp_vector_store(embed_model=embed_model) item.obj = vector_store item.finish_loading() else: diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index ceb4deb7..e9a22b21 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,7 +1,7 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, +from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, logger, log_verbose, ) @@ -42,22 +42,6 @@ def search_docs( return data -def update_docs_by_id( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}") -) -> BaseResponse: - ''' - 按照文档 ID 更新文档内容 - ''' - kb = KBServiceFactory.get_service_by_name(knowledge_base_name) - if kb is None: - return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在") - if kb.update_doc_by_ids(docs=docs): - return BaseResponse(msg=f"文档更新成功") - else: - return BaseResponse(msg=f"文档更新失败") - - def list_files( knowledge_base_name: str ) -> ListResponse: @@ -230,26 +214,6 @@ def update_info( return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info}) -def update_kb_endpoint( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), -): - if not validate_kb_name(knowledge_base_name): - return BaseResponse(code=403, msg="Don't attack me") - - kb = KBServiceFactory.get_service_by_name(knowledge_base_name) - if kb is None: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb.update_kb_endpoint(endpoint_host, endpoint_host_key, endpoint_host_proxy) - - return BaseResponse(code=200, msg=f"知识库在线api接入点配置修改完成", - data={"endpoint_host": endpoint_host, - "endpoint_host_key": endpoint_host_key, - "endpoint_host_proxy": endpoint_host_proxy}) - - def update_docs( knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]), @@ -366,10 +330,7 @@ def recreate_vector_store( knowledge_base_name: str = Body(..., examples=["samples"]), allow_empty_kb: bool = Body(True), vs_type: str = Body(DEFAULT_VS_TYPE), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), - embed_model: str = Body(EMBEDDING_MODEL), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), @@ -389,9 +350,7 @@ def recreate_vector_store( else: if kb.exists(): kb.clear_vs() - kb.create_kb(endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - 
endpoint_host_proxy=endpoint_host_proxy) + kb.create_kb() files = list_files_from_folder(knowledge_base_name) kb_files = [(file, knowledge_base_name) for file in files] i = 0 diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 336f4b85..02962527 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -6,7 +6,7 @@ from langchain.docstore.document import Document from server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, - load_kb_from_db, get_kb_detail, update_kb_endpoint_from_db, + load_kb_from_db, get_kb_detail, ) from server.db.repository.knowledge_file_repository import ( add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, @@ -15,7 +15,7 @@ from server.db.repository.knowledge_file_repository import ( ) from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - EMBEDDING_MODEL, KB_INFO) + DEFAULT_EMBEDDING_MODEL, KB_INFO) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, @@ -40,7 +40,7 @@ class KBService(ABC): def __init__(self, knowledge_base_name: str, - embed_model: str = EMBEDDING_MODEL, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") @@ -58,20 +58,14 @@ class KBService(ABC): ''' pass - def create_kb(self, - endpoint_host: str = None, - endpoint_host_key: str = None, - endpoint_host_proxy: str = None): + def create_kb(self): """ 创建知识库 """ if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) - status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model, - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy) + status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) if status: self.do_create_kb() @@ -144,16 +138,6 @@ class KBService(ABC): status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model) return status - def update_kb_endpoint(self, - endpoint_host: str = None, - endpoint_host_key: str = None, - endpoint_host_proxy: str = None): - """ - 更新知识库在线api接入点配置 - """ - status = update_kb_endpoint_from_db(self.kb_name, endpoint_host, endpoint_host_key, endpoint_host_proxy) - return status - def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 使用content中的文件更新向量库 @@ -297,7 +281,7 @@ class KBServiceFactory: @staticmethod def get_service(kb_name: str, vector_store_type: Union[str, SupportedVSType], - embed_model: str = EMBEDDING_MODEL, + embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> KBService: if isinstance(vector_store_type, str): vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) diff --git a/server/knowledge_base/kb_service/es_kb_service.py b/server/knowledge_base/kb_service/es_kb_service.py index afdfd70c..327633ef 100644 --- a/server/knowledge_base/kb_service/es_kb_service.py +++ b/server/knowledge_base/kb_service/es_kb_service.py @@ -1,13 +1,11 @@ from typing import List import os import shutil -from langchain.embeddings.base import Embeddings from langchain.schema import Document from langchain.vectorstores.elasticsearch import ElasticsearchStore -from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM from server.knowledge_base.kb_service.base import KBService, 
SupportedVSType from server.knowledge_base.utils import KnowledgeFile -from server.utils import load_local_embeddings +from server.utils import get_Embeddings from elasticsearch import Elasticsearch,BadRequestError from configs import logger from configs import kbs_config @@ -22,7 +20,7 @@ class ESKBService(KBService): self.user = kbs_config[self.vs_type()].get("user",'') self.password = kbs_config[self.vs_type()].get("password",'') self.dims_length = kbs_config[self.vs_type()].get("dims_length",None) - self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE) + self.embeddings_model = get_Embeddings(self.embed_model) try: # ES python客户端连接(仅连接) if self.user != "" and self.password != "": diff --git a/server/knowledge_base/kb_summary/base.py b/server/knowledge_base/kb_summary/base.py index 6d095fee..dae3dc9e 100644 --- a/server/knowledge_base/kb_summary/base.py +++ b/server/knowledge_base/kb_summary/base.py @@ -1,7 +1,7 @@ from typing import List from configs import ( - EMBEDDING_MODEL, + DEFAULT_EMBEDDING_MODEL, KB_ROOT_PATH) from abc import ABC, abstractmethod @@ -21,7 +21,7 @@ class KBSummaryService(ABC): def __init__(self, knowledge_base_name: str, - embed_model: str = EMBEDDING_MODEL + embed_model: str = DEFAULT_EMBEDDING_MODEL ): self.kb_name = knowledge_base_name self.embed_model = embed_model diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 674fde11..86263e8e 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -1,5 +1,5 @@ from fastapi import Body -from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, +from configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, OVERLAP_SIZE, logger, log_verbose, ) from server.knowledge_base.utils import (list_files_from_folder) @@ -17,11 +17,8 @@ def recreate_summary_vector_store( knowledge_base_name: str = Body(..., examples=["samples"]), allow_empty_kb: bool = Body(True), vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), file_description: str = Body(''), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -29,9 +26,6 @@ def recreate_summary_vector_store( """ 重建单个知识库文件摘要 :param max_tokens: - :param endpoint_host: - :param endpoint_host_key: - :param endpoint_host_proxy: :param model_name: :param temperature: :param file_description: @@ -54,17 +48,11 @@ def recreate_summary_vector_store( kb_summary.create_kb_summary() llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, ) reduce_llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, @@ -110,20 +98,14 @@ def summary_file_to_vector_store( file_name: str = Body(..., examples=["test.pdf"]), allow_empty_kb: bool = Body(True), vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), 
file_description: str = Body(''), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ): """ 单个知识库根据文件名称摘要 - :param endpoint_host: - :param endpoint_host_key: - :param endpoint_host_proxy: :param model_name: :param max_tokens: :param temperature: @@ -146,17 +128,11 @@ def summary_file_to_vector_store( kb_summary.create_kb_summary() llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, ) reduce_llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, @@ -194,11 +170,8 @@ def summary_doc_ids_to_vector_store( knowledge_base_name: str = Body(..., examples=["samples"]), doc_ids: List = Body([], examples=[["uuid"]]), vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), + embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), file_description: str = Body(''), - endpoint_host: str = Body(None, description="接入点地址"), - endpoint_host_key: str = Body(None, description="接入点key"), - endpoint_host_proxy: str = Body(None, description="接入点代理地址"), model_name: str = Body(None, description="LLM 模型名称。"), temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), @@ -206,9 +179,6 @@ def summary_doc_ids_to_vector_store( """ 单个知识库根据doc_ids摘要 :param knowledge_base_name: - :param endpoint_host: - :param endpoint_host_key: - :param endpoint_host_proxy: :param doc_ids: :param model_name: :param max_tokens: @@ -223,17 +193,11 @@ def summary_doc_ids_to_vector_store( return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}) else: llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, ) reduce_llm = get_ChatOpenAI( - endpoint_host=endpoint_host, - endpoint_host_key=endpoint_host_key, - endpoint_host_proxy=endpoint_host_proxy, model_name=model_name, temperature=temperature, max_tokens=max_tokens, diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 596e1f62..a4817476 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -1,5 +1,5 @@ from configs import ( - EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, + DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose ) @@ -86,7 +86,7 @@ def folder2db( kb_names: List[str], mode: Literal["recreate_vs", "update_in_db", "increment"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, - embed_model: str = EMBEDDING_MODEL, + embed_model: str = DEFAULT_EMBEDDING_MODEL, chunk_size: int = CHUNK_SIZE, chunk_overlap: int = OVERLAP_SIZE, zh_title_enhance: bool = ZH_TITLE_ENHANCE, diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 7cb281e8..d37b52c7 100644 --- a/server/knowledge_base/utils.py +++ 
b/server/knowledge_base/utils.py @@ -1,4 +1,5 @@ import os +from functools import lru_cache from configs import ( KB_ROOT_PATH, CHUNK_SIZE, @@ -143,6 +144,7 @@ def get_LoaderClass(file_extension): if file_extension in extensions: return LoaderClass + def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): ''' 根据loader_name和文件路径或内容返回文档加载器。 @@ -184,6 +186,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): return loader +@lru_cache() def make_text_splitter( splitter_name, chunk_size, diff --git a/server/localai_embeddings.py b/server/localai_embeddings.py new file mode 100644 index 00000000..f52681a9 --- /dev/null +++ b/server/localai_embeddings.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import logging +import warnings +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain_community.utils.openai import is_openai_v1 +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + async_retrying = AsyncRetrying( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + def wrap(func: Callable) -> Callable: + async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: + async for _ in async_retrying: + return await func(*args, **kwargs) + raise AssertionError("this is unreachable") + + return wrapped_f + + return wrap + + +# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings +def _check_response(response: dict) -> dict: + if any([len(d.embedding) == 1 for d in response.data]): + import openai + + raise openai.APIError("LocalAI API returned an empty embedding") + return response + + +def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the 
embedding call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + response = embeddings.client.create(**kwargs) + return _check_response(response) + + return _embed_with_retry(**kwargs) + + +async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + + @_async_retry_decorator(embeddings) + async def _async_embed_with_retry(**kwargs: Any) -> Any: + response = await embeddings.async_client.acreate(**kwargs) + return _check_response(response) + + return await _async_embed_with_retry(**kwargs) + + +class LocalAIEmbeddings(BaseModel, Embeddings): + """LocalAI embedding models. + + Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class + uses the ``openai`` Python package's ``openai.Embedding`` as its client. + Thus, you should have the ``openai`` python package installed, and defeat + the environment variable ``OPENAI_API_KEY`` by setting to a random string. + You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI + service endpoint. + + Example: + .. code-block:: python + + from langchain_community.embeddings import LocalAIEmbeddings + openai = LocalAIEmbeddings( + openai_api_key="random-string", + openai_api_base="http://localhost:8080" + ) + + """ + + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: + model: str = "text-embedding-ada-002" + deployment: str = model + openai_api_version: Optional[str] = None + openai_api_base: Optional[str] = Field(default=None, alias="base_url") + # to support explicit proxy for LocalAI + openai_proxy: Optional[str] = None + embedding_ctx_length: int = 8191 + """The maximum number of tokens to embed at once.""" + openai_api_key: Optional[str] = Field(default=None, alias="api_key") + openai_organization: Optional[str] = Field(default=None, alias="organization") + allowed_special: Union[Literal["all"], Set[str]] = set() + disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" + chunk_size: int = 1000 + """Maximum number of texts to embed in each batch""" + max_retries: int = 6 + """Maximum number of retries to make when generating.""" + request_timeout: Union[float, Tuple[float, float], Any, None] = Field( + default=None, alias="timeout" + ) + """Timeout in seconds for the LocalAI request.""" + headers: Any = None + show_progress_bar: bool = False + """Whether to show a progress bar when embedding.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + warnings.warn( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. 
+ Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) + values["openai_api_base"] = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) + values["openai_proxy"] = get_from_dict_or_env( + values, + "openai_proxy", + "OPENAI_PROXY", + default="", + ) + + default_api_version = "" + values["openai_api_version"] = get_from_dict_or_env( + values, + "openai_api_version", + "OPENAI_API_VERSION", + default=default_api_version, + ) + values["openai_organization"] = get_from_dict_or_env( + values, + "openai_organization", + "OPENAI_ORGANIZATION", + default="", + ) + try: + import openai + + if is_openai_v1(): + client_params = { + "api_key": values["openai_api_key"], + "organization": values["openai_organization"], + "base_url": values["openai_api_base"], + "timeout": values["request_timeout"], + "max_retries": values["max_retries"], + } + + if not values.get("client"): + values["client"] = openai.OpenAI(**client_params).embeddings + if not values.get("async_client"): + values["async_client"] = openai.AsyncOpenAI( + **client_params + ).embeddings + elif not values.get("client"): + values["client"] = openai.Embedding + else: + pass + except ImportError: + raise ImportError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + return values + + @property + def _invocation_params(self) -> Dict: + openai_args = { + "model": self.model, + "timeout": self.request_timeout, + "extra_headers": self.headers, + **self.model_kwargs, + } + if self.openai_proxy: + import openai + + openai.proxy = { + "http": self.openai_proxy, + "https": self.openai_proxy, + } # type: ignore[assignment] # noqa: E501 + return openai_args + + def _embedding_func(self, text: str, *, engine: str) -> List[float]: + """Call out to LocalAI's embedding endpoint.""" + # handle large input text + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + return embed_with_retry( + self, + input=[text], + **self._invocation_params, + ).data[0].embedding + + async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: + """Call out to LocalAI's embedding endpoint.""" + # handle large input text + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + return ( + await async_embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) + ).data[0].embedding + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to LocalAI's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. 
+ chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + # call _embedding_func for each text + return [self._embedding_func(text, engine=self.deployment) for text in texts] + + async def aembed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to LocalAI's embedding endpoint async for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + embeddings = [] + for text in texts: + response = await self._aembedding_func(text, engine=self.deployment) + embeddings.append(response) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Call out to LocalAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = self._embedding_func(text, engine=self.deployment) + return embedding + + async def aembed_query(self, text: str) -> List[float]: + """Call out to LocalAI's embedding endpoint async for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = await self._aembedding_func(text, engine=self.deployment) + return embedding diff --git a/server/reranker/reranker.py b/server/reranker/reranker.py index c6cbebfa..5a2de3c4 100644 --- a/server/reranker/reranker.py +++ b/server/reranker/reranker.py @@ -109,14 +109,13 @@ if __name__ == "__main__": RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH) - from server.utils import embedding_device if USE_RERANKER: reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large") print("-----------------model path------------------") print(reranker_model_path) reranker_model = LangchainReranker(top_n=3, - device=embedding_device(), + device="cpu", max_length=RERANKER_MAX_LENGTH, model_name_or_path=reranker_model_path ) diff --git a/server/utils.py b/server/utils.py index ab3a3dcc..05b68bbb 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,32 +1,28 @@ -import pydantic -from pydantic import BaseModel -from typing import List from fastapi import FastAPI from pathlib import Path import asyncio -from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE, - MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose, - HTTPX_DEFAULT_TIMEOUT) import os from concurrent.futures import ThreadPoolExecutor, as_completed +from langchain.pydantic_v1 import BaseModel, Field +from langchain.embeddings.base import Embeddings from langchain_openai.chat_models import ChatOpenAI -from langchain_community.llms import OpenAI +from langchain_openai.llms import OpenAI import httpx from typing import ( - TYPE_CHECKING, - Literal, Optional, Callable, Generator, Dict, + List, Any, Awaitable, Union, - Tuple + Tuple, + Literal, ) import logging -import torch +from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT, DEFAULT_EMBEDDING_MODEL from server.minx_chat_openai import MinxChatOpenAI @@ -44,10 +40,66 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): event.set() +def get_config_models( + model_name: str = None, + model_type: Literal["llm", "embed", "image", "multimodal"] = None, + platform_name: str = None, +) -> Dict[str, Dict]: + ''' + 获取配置的模型列表,返回值为: + {model_name: { + "platform_name": xx, + "platform_type": xx, + 
"model_type": xx, + "model_name": xx, + "api_base_url": xx, + "api_key": xx, + "api_proxy": xx, + }} + ''' + import importlib + from configs import model_config + importlib.reload(model_config) + + result = {} + for m in model_config.MODEL_PLATFORMS: + if platform_name is not None and platform_name != m.get("platform_name"): + continue + if model_type is not None and f"{model_type}_models" not in m: + continue + + if model_type is None: + model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"] + else: + model_types = [f"{model_type}_models"] + + for m_type in model_types: + for m_name in m.get(m_type, []): + if model_name is None or model_name == m_name: + result[m_name] = { + "platform_name": m.get("platform_name"), + "platform_type": m.get("platform_type"), + "model_type": m_type.split("_")[0], + "model_name": m_name, + "api_base_url": m.get("api_base_url"), + "api_key": m.get("api_key"), + "api_proxy": m.get("api_proxy"), + } + return result + + +def get_model_info(model_name: str, platform_name: str = None) -> Dict: + ''' + 获取配置的模型信息,主要是 api_base_url, api_key + ''' + result = get_config_models(model_name=model_name, platform_name=platform_name) + if len(result) > 0: + return list(result.values())[0] + else: + return {} + + def get_ChatOpenAI( - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, model_name: str, temperature: float, max_tokens: int = None, @@ -56,29 +108,23 @@ def get_ChatOpenAI( verbose: bool = True, **kwargs: Any, ) -> ChatOpenAI: - config = get_model_worker_config(model_name) - if model_name == "openai-api": - model_name = config.get("model_name") - ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model + model_info = get_model_info(model_name) model = ChatOpenAI( streaming=streaming, verbose=verbose, callbacks=callbacks, - openai_api_key=endpoint_host_key if endpoint_host_key else "None", - openai_api_base=endpoint_host if endpoint_host else "None", model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, + openai_api_key=model_info.get("api_key"), + openai_api_base=model_info.get("api_base_url"), + openai_proxy=model_info.get("api_proxy"), **kwargs ) return model def get_OpenAI( - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str, model_name: str, temperature: float, max_tokens: int = None, @@ -89,22 +135,40 @@ def get_OpenAI( **kwargs: Any, ) -> OpenAI: # TODO: 从API获取模型信息 + model_info = get_model_info(model_name) model = OpenAI( streaming=streaming, verbose=verbose, callbacks=callbacks, - openai_api_key=endpoint_host_key if endpoint_host_key else "None", - openai_api_base=endpoint_host if endpoint_host else "None", model_name=model_name, temperature=temperature, max_tokens=max_tokens, - openai_proxy=endpoint_host_proxy if endpoint_host_proxy else None, + openai_api_key=model_info.get("api_key"), + openai_api_base=model_info.get("api_base_url"), + openai_proxy=model_info.get("api_proxy"), echo=echo, **kwargs ) return model +def get_Embeddings(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> Embeddings: + from langchain_community.embeddings.openai import OpenAIEmbeddings + from server.localai_embeddings import LocalAIEmbeddings # TODO: fork of lc pr #17154 + + model_info = get_model_info(model_name=embed_model) + params = { + "model": embed_model, + "base_url": model_info.get("api_base_url"), + "api_key": model_info.get("api_key"), + "openai_proxy": model_info.get("api_proxy"), + } + if 
model_info.get("platform_type") == "openai": + return OpenAIEmbeddings(**params) + else: + return LocalAIEmbeddings(**params) + + class MsgType: TEXT = 1 IMAGE = 2 @@ -113,9 +177,9 @@ class MsgType: class BaseResponse(BaseModel): - code: int = pydantic.Field(200, description="API status code") - msg: str = pydantic.Field("success", description="API status message") - data: Any = pydantic.Field(None, description="API data") + code: int = Field(200, description="API status code") + msg: str = Field("success", description="API status message") + data: Any = Field(None, description="API data") class Config: schema_extra = { @@ -127,7 +191,7 @@ class BaseResponse(BaseModel): class ListResponse(BaseResponse): - data: List[str] = pydantic.Field(..., description="List of names") + data: List[str] = Field(..., description="List of names") class Config: schema_extra = { @@ -140,10 +204,10 @@ class ListResponse(BaseResponse): class ChatMessage(BaseModel): - question: str = pydantic.Field(..., description="Question text") - response: str = pydantic.Field(..., description="Response text") - history: List[List[str]] = pydantic.Field(..., description="History text") - source_documents: List[str] = pydantic.Field( + question: str = Field(..., description="Question text") + response: str = Field(..., description="Response text") + history: List[List[str]] = Field(..., description="History text") + source_documents: List[str] = Field( ..., description="List of source documents and their scores" ) @@ -310,39 +374,40 @@ def MakeFastAPIOffline( # 从model_config中获取模型信息 +# TODO: 移出模型加载后,这些功能需要删除或改变实现 -def list_embed_models() -> List[str]: - ''' - get names of configured embedding models - ''' - return list(MODEL_PATH["embed_model"]) +# def list_embed_models() -> List[str]: +# ''' +# get names of configured embedding models +# ''' +# return list(MODEL_PATH["embed_model"]) -def get_model_path(model_name: str, type: str = None) -> Optional[str]: - if type in MODEL_PATH: - paths = MODEL_PATH[type] - else: - paths = {} - for v in MODEL_PATH.values(): - paths.update(v) +# def get_model_path(model_name: str, type: str = None) -> Optional[str]: +# if type in MODEL_PATH: +# paths = MODEL_PATH[type] +# else: +# paths = {} +# for v in MODEL_PATH.values(): +# paths.update(v) - if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径 - path = Path(path_str) - if path.is_dir(): # 任意绝对路径 - return str(path) +# if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径 +# path = Path(path_str) +# if path.is_dir(): # 任意绝对路径 +# return str(path) - root_path = Path(MODEL_ROOT_PATH) - if root_path.is_dir(): - path = root_path / model_name - if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b - return str(path) - path = root_path / path_str - if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new - return str(path) - path = root_path / path_str.split("/")[-1] - if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new - return str(path) - return path_str # THUDM/chatglm06b +# root_path = Path(MODEL_ROOT_PATH) +# if root_path.is_dir(): +# path = root_path / model_name +# if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b +# return str(path) +# path = root_path / path_str +# if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new +# return str(path) +# path = root_path / path_str.split("/")[-1] +# if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new +# return str(path) +# return path_str 
# THUDM/chatglm06b def api_address() -> str: @@ -429,37 +494,6 @@ def set_httpx_config( urllib.request.getproxies = _get_proxies -def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]: - try: - import torch - if torch.cuda.is_available(): - return "cuda" - if torch.backends.mps.is_available(): - return "mps" - import intel_extension_for_pytorch as ipex - if torch.xpu.get_device_properties(0): - return "xpu" - except: - pass - return "cpu" - - -def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]: - device = device or LLM_DEVICE - # if device.isdigit(): - # return "cuda:" + device - if device not in ["cuda", "mps", "cpu", "xpu"]: - device = detect_device() - return device - - -def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]: - device = device or EMBEDDING_DEVICE - if device not in ["cuda", "mps", "cpu", "xpu"]: - device = detect_device() - return device - - def run_in_thread_pool( func: Callable, params: List[Dict] = [], @@ -546,56 +580,19 @@ def get_server_configs() -> Dict: return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom} -def list_online_embed_models( - endpoint_host: str, - endpoint_host_key: str, - endpoint_host_proxy: str -) -> List[str]: - ret = [] - # TODO: 从在线API获取支持的模型列表 - client = get_httpx_client(base_url=endpoint_host, proxies=endpoint_host_proxy, timeout=HTTPX_DEFAULT_TIMEOUT) - try: - headers = { - "Authorization": f"Bearer {endpoint_host_key}", - } - resp = client.get("/models", headers=headers) - if resp.status_code == 200: - models = resp.json().get("data", []) - for model in models: - if "embedding" in model.get("id", None): - ret.append(model.get("id")) - - except Exception as e: - msg = f"获取在线Embeddings模型列表失败:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - finally: - client.close() - return ret - - -def load_local_embeddings(model: str = None, device: str = embedding_device()): - ''' - 从缓存中本地Embeddings模型加载,可以避免多线程时竞争加载。 - ''' - from server.knowledge_base.kb_cache.base import embeddings_pool - from configs import EMBEDDING_MODEL - - model = model or EMBEDDING_MODEL - return embeddings_pool.load_embeddings(model=model, device=device) - - def get_temp_dir(id: str = None) -> Tuple[str, str]: ''' 创建一个临时目录,返回(路径,文件夹名称) ''' from configs.basic_config import BASE_TEMP_DIR - import tempfile + import uuid if id is not None: # 如果指定的临时目录已存在,直接返回 path = os.path.join(BASE_TEMP_DIR, id) if os.path.isdir(path): return path, id - path = tempfile.mkdtemp(dir=BASE_TEMP_DIR) - return path, os.path.basename(path) + id = uuid.uuid4().hex + path = os.path.join(BASE_TEMP_DIR, id) + os.mkdir(path) + return path, id diff --git a/startup.py b/startup.py index aef47b7e..b1db023d 100644 --- a/startup.py +++ b/startup.py @@ -1,12 +1,11 @@ import asyncio +from contextlib import asynccontextmanager import multiprocessing as mp import os import subprocess import sys from multiprocessing import Process -from datetime import datetime -from pprint import pprint -from langchain_core._api import deprecated + # 设置numexpr最大线程数,默认为CPU核心数 try: @@ -17,38 +16,29 @@ try: except: pass -sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs import ( LOG_PATH, log_verbose, logger, - LLM_MODEL_CONFIG, - EMBEDDING_MODEL, + DEFAULT_EMBEDDING_MODEL, TEXT_SPLITTER_NAME, API_SERVER, WEBUI_SERVER, - HTTPX_DEFAULT_TIMEOUT, ) -from server.utils import (FastAPI, embedding_device) +from server.utils import FastAPI from server.knowledge_base.migrate import create_tables import 
argparse from typing import List, Dict from configs import VERSION -all_model_names = set() -for model_category in LLM_MODEL_CONFIG.values(): - for model_name in model_category.keys(): - if model_name not in all_model_names: - all_model_names.add(model_name) - -all_model_names_list = list(all_model_names) - def _set_app_event(app: FastAPI, started_event: mp.Event = None): - @app.on_event("startup") - async def on_startup(): + @asynccontextmanager + async def lifespan(app: FastAPI): if started_event is not None: started_event.set() + yield + app.router.lifespan_context = lifespan def run_api_server(started_event: mp.Event = None, run_mode: str = None): @@ -159,7 +149,7 @@ def dump_server_info(after_start=False, args=None): print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}") - print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") + print(f"当前Embbedings模型: {DEFAULT_EMBEDDING_MODEL}") if after_start: print("\n") @@ -232,13 +222,13 @@ async def start_main_server(): return len(processes) loom_started = manager.Event() - process = Process( - target=run_loom, - name=f"run_loom Server", - kwargs=dict(started_event=loom_started), - daemon=True, - ) - processes["run_loom"] = process + # process = Process( + # target=run_loom, + # name=f"run_loom Server", + # kwargs=dict(started_event=loom_started), + # daemon=True, + # ) + # processes["run_loom"] = process api_started = manager.Event() if args.api: process = Process( @@ -283,7 +273,6 @@ async def start_main_server(): # 等待所有进程退出 if p := processes.get("webui"): - p.join() except Exception as e: logger.error(e) @@ -306,9 +295,7 @@ async def start_main_server(): if __name__ == "__main__": - create_tables() - if sys.version_info < (3, 10): loop = asyncio.get_event_loop() else: diff --git a/webui.py b/webui.py index 9e300003..b305f96a 100644 --- a/webui.py +++ b/webui.py @@ -1,7 +1,7 @@ import streamlit as st -from webui_pages.loom_view_client import update_store -from webui_pages.openai_plugins import openai_plugins_page +# from webui_pages.loom_view_client import update_store +# from webui_pages.openai_plugins import openai_plugins_page from webui_pages.utils import * from streamlit_option_menu import option_menu from webui_pages.dialogue.dialogue import dialogue_page, chat_box @@ -12,9 +12,9 @@ from configs import VERSION from server.utils import api_address -def on_change(key): - if key: - update_store() +# def on_change(key): +# if key: +# update_store() api = ApiRequest(base_url=api_address()) @@ -59,18 +59,18 @@ if __name__ == "__main__": "icon": "hdd-stack", "func": knowledge_base_page, }, - "模型服务": { - "icon": "hdd-stack", - "func": openai_plugins_page, - }, + # "模型服务": { + # "icon": "hdd-stack", + # "func": openai_plugins_page, + # }, } # 更新状态 - if "status" not in st.session_state \ - or "run_plugins_list" not in st.session_state \ - or "launch_subscribe_info" not in st.session_state \ - or "list_running_models" not in st.session_state \ - or "model_config" not in st.session_state: - update_store() + # if "status" not in st.session_state \ + # or "run_plugins_list" not in st.session_state \ + # or "launch_subscribe_info" not in st.session_state \ + # or "list_running_models" not in st.session_state \ + # or "model_config" not in st.session_state: + # update_store() with st.sidebar: st.image( @@ -95,7 +95,6 @@ if __name__ == "__main__": icons=icons, # menu_icon="chat-quote", default_index=default_index, - on_change=on_change, ) if selected_page in pages: diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 
c9759115..9aedb5ac 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -4,8 +4,8 @@ import streamlit as st from streamlit_antd_components.utils import ParseItems from webui_pages.dialogue.utils import process_files -from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \ - get_select_model_endpoint +# from webui_pages.loom_view_client import build_providers_model_plugins_name, find_menu_items_by_index, set_llm_select, \ +# get_select_model_endpoint from webui_pages.utils import * from streamlit_chatbox import * from streamlit_modal import Modal @@ -13,9 +13,9 @@ from datetime import datetime import os import re import time -from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG, OPENAI_KEY, OPENAI_PROXY) +from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG) from server.callback_handler.agent_callback_handler import AgentStatus -from server.utils import MsgType +from server.utils import MsgType, get_config_models import uuid from typing import List, Dict @@ -111,8 +111,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): st.session_state.setdefault("conversation_ids", {}) st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex) st.session_state.setdefault("file_chat_id", None) - st.session_state.setdefault("select_plugins_info", None) - st.session_state.setdefault("select_model_worker", None) # 弹出自定义命令帮助信息 modal = Modal("自定义命令", key="cmd_help", max_width="500") @@ -131,18 +129,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.use_chat_name(conversation_name) conversation_id = st.session_state["conversation_ids"][conversation_name] - with st.expander("模型选择"): - plugins_menu = build_providers_model_plugins_name() - - items, _ = ParseItems(plugins_menu).multi() - - if len(plugins_menu) > 0: - - llm_model_index = sac.menu(plugins_menu, index=1, return_index=True) - plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index) - set_llm_select(plugins_info, llm_model_worker) - else: - st.info("没有可用的插件") + platforms = [x["platform_name"] for x in MODEL_PLATFORMS] + platform = st.selectbox("选择模型平台", platforms, 1) + llm_models = list(get_config_models(model_type="llm", platform_name=platform)) + llm_model = st.selectbox("选择LLM模型", llm_models) # 传入后端的内容 chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} @@ -174,10 +164,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): if is_selected: selected_tool_configs[tool] = TOOL_CONFIG[tool] - llm_model = None - if st.session_state["select_model_worker"] is not None: - llm_model = st.session_state["select_model_worker"]['label'] - if llm_model is not None: chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {}) @@ -200,23 +186,23 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.output_messages() chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 " - def on_feedback( - feedback, - message_id: str = "", - history_index: int = -1, - ): + # def on_feedback( + # feedback, + # message_id: str = "", + # history_index: int = -1, + # ): - reason = feedback["text"] - score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) - api.chat_feedback(message_id=message_id, - score=score_int, - reason=reason) - st.session_state["need_rerun"] = True + # reason = feedback["text"] + # score_int = 
chat_box.set_feedback(feedback=feedback, history_index=history_index) + # api.chat_feedback(message_id=message_id, + # score=score_int, + # reason=reason) + # st.session_state["need_rerun"] = True - feedback_kwargs = { - "feedback_type": "thumbs", - "optional_text_label": "欢迎反馈您打分的理由", - } + # feedback_kwargs = { + # "feedback_type": "thumbs", + # "optional_text_label": "欢迎反馈您打分的理由", + # } if prompt := st.chat_input(chat_input_placeholder, key="prompt"): if parse_command(text=prompt, modal=modal): # 用户输入自定义命令 @@ -244,17 +230,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): text_action = "" element_index = 0 - openai_config = {} - endpoint_host, select_model_name = get_select_model_endpoint() - openai_config["endpoint_host"] = endpoint_host - openai_config["model_name"] = select_model_name - openai_config["endpoint_host_key"] = OPENAI_KEY - openai_config["endpoint_host_proxy"] = OPENAI_PROXY for d in api.chat_chat(query=prompt, metadata=files_upload, history=history, chat_model_config=chat_model_config, - openai_config=openai_config, conversation_id=conversation_id, tool_config=selected_tool_configs, ): diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 37a7b1aa..e6212063 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,8 +1,8 @@ import streamlit as st from streamlit_antd_components.utils import ParseItems -from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \ - set_llm_select, set_embed_select, get_select_embed_endpoint +# from webui_pages.loom_view_client import build_providers_embedding_plugins_name, find_menu_items_by_index, \ +# set_llm_select, set_embed_select, get_select_embed_endpoint from webui_pages.utils import * from st_aggrid import AgGrid, JsCode from st_aggrid.grid_options_builder import GridOptionsBuilder @@ -10,10 +10,9 @@ import pandas as pd from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from typing import Literal, Dict, Tuple -from configs import (kbs_config, - EMBEDDING_MODEL, DEFAULT_VS_TYPE, - CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, OPENAI_KEY, OPENAI_PROXY) -from server.utils import list_embed_models +from configs import (kbs_config, DEFAULT_VS_TYPE, + CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) +from server.utils import get_config_models import streamlit_antd_components as sac import os @@ -116,25 +115,11 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): col1, _ = st.columns([3, 1]) with col1: - col1.text("Embedding 模型") - plugins_menu = build_providers_embedding_plugins_name() - - embed_models = list_embed_models() - menu_item_children = [] - for model in embed_models: - menu_item_children.append(sac.MenuItem(model, description=model)) - - plugins_menu.append(sac.MenuItem("本地Embedding 模型", icon='box-fill', children=menu_item_children)) - - items, _ = ParseItems(plugins_menu).multi() - - if len(plugins_menu) > 0: - - llm_model_index = sac.menu(plugins_menu, index=1, return_index=True, height=300, open_all=False) - plugins_info, llm_model_worker = find_menu_items_by_index(items, llm_model_index) - set_embed_select(plugins_info, llm_model_worker) - else: - st.info("没有可用的插件") + embed_models = list(get_config_models(model_type="embed")) + index = 0 + if DEFAULT_EMBEDDING_MODEL in embed_models: + index = embed_models.index(DEFAULT_EMBEDDING_MODEL) + 
embed_model = st.selectbox("Embeddings模型", embed_models, index) submit_create_kb = st.form_submit_button( "新建", @@ -143,23 +128,17 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): ) if submit_create_kb: - - endpoint_host, select_embed_model_name = get_select_embed_endpoint() if not kb_name or not kb_name.strip(): st.error(f"知识库名称不能为空!") elif kb_name in kb_list: st.error(f"名为 {kb_name} 的知识库已经存在!") - elif select_embed_model_name is None: + elif embed_model is None: st.error(f"请选择Embedding模型!") else: - ret = api.create_knowledge_base( knowledge_base_name=kb_name, vector_store_type=vs_type, - embed_model=select_embed_model_name, - endpoint_host=endpoint_host, - endpoint_host_key=OPENAI_KEY, - endpoint_host_proxy=OPENAI_PROXY, + embed_model=embed_model, ) st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name @@ -169,9 +148,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): elif selected_kb: kb = selected_kb st.session_state["selected_kb_info"] = kb_list[kb]['kb_info'] - st.session_state["kb_endpoint_host"] = kb_list[kb]['endpoint_host'] - st.session_state["kb_endpoint_host_key"] = kb_list[kb]['endpoint_host_key'] - st.session_state["kb_endpoint_host_proxy"] = kb_list[kb]['endpoint_host_proxy'] # 上传文件 files = st.file_uploader("上传知识文件:", [i for ls in LOADER_DICT.values() for i in ls], @@ -185,37 +161,6 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.session_state["selected_kb_info"] = kb_info api.update_kb_info(kb, kb_info) - if st.session_state["kb_endpoint_host"] is not None: - with st.expander( - "在线api接入点配置", - expanded=True, - ): - endpoint_host = st.text_input( - "接入点地址", - placeholder="接入点地址", - key="endpoint_host", - value=st.session_state["kb_endpoint_host"], - ) - endpoint_host_key = st.text_input( - "接入点key", - placeholder="接入点key", - key="endpoint_host_key", - value=st.session_state["kb_endpoint_host_key"], - ) - endpoint_host_proxy = st.text_input( - "接入点代理地址", - placeholder="接入点代理地址", - key="endpoint_host_proxy", - value=st.session_state["kb_endpoint_host_proxy"], - ) - if endpoint_host != st.session_state["kb_endpoint_host"] \ - or endpoint_host_key != st.session_state["kb_endpoint_host_key"] \ - or endpoint_host_proxy != st.session_state["kb_endpoint_host_proxy"]: - st.session_state["kb_endpoint_host"] = endpoint_host - st.session_state["kb_endpoint_host_key"] = endpoint_host_key - st.session_state["kb_endpoint_host_proxy"] = endpoint_host_proxy - api.update_kb_endpoint(kb, endpoint_host, endpoint_host_key, endpoint_host_proxy) - # with st.sidebar: with st.expander( "文件处理配置", diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 40d1377b..8ff988a7 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -4,7 +4,7 @@ from typing import * from pathlib import Path from configs import ( - EMBEDDING_MODEL, + DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, LLM_MODEL_CONFIG, SCORE_THRESHOLD, @@ -266,7 +266,6 @@ class ApiRequest: history: List[Dict] = [], stream: bool = True, chat_model_config: Dict = None, - openai_config: Dict = None, tool_config: Dict = None, **kwargs, ): @@ -281,7 +280,6 @@ class ApiRequest: "history": history, "stream": stream, "chat_model_config": chat_model_config, - "openai_config": openai_config, "tool_config": tool_config, } @@ -381,10 +379,7 @@ class ApiRequest: self, knowledge_base_name: str, vector_store_type: str = DEFAULT_VS_TYPE, - embed_model: str = EMBEDDING_MODEL, - endpoint_host: str = None, - endpoint_host_key: str = None, - endpoint_host_proxy: str = None + 
embed_model: str = DEFAULT_EMBEDDING_MODEL, ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 @@ -393,9 +388,6 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model, - "endpoint_host": endpoint_host, - "endpoint_host_key": endpoint_host_key, - "endpoint_host_proxy": endpoint_host_proxy, } response = self.post( @@ -459,24 +451,6 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) - def update_docs_by_id( - self, - knowledge_base_name: str, - docs: Dict[str, Dict], - ) -> bool: - ''' - 对应api.py/knowledge_base/update_docs_by_id接口 - ''' - data = { - "knowledge_base_name": knowledge_base_name, - "docs": docs, - } - response = self.post( - "/knowledge_base/update_docs_by_id", - json=data - ) - return self._get_response_value(response) - def upload_kb_docs( self, files: List[Union[str, Path, bytes]], @@ -562,26 +536,6 @@ class ApiRequest: ) return self._get_response_value(response, as_json=True) - def update_kb_endpoint(self, - knowledge_base_name, - endpoint_host: str = None, - endpoint_host_key: str = None, - endpoint_host_proxy: str = None): - ''' - 对应api.py/knowledge_base/update_info接口 - ''' - data = { - "knowledge_base_name": knowledge_base_name, - "endpoint_host": endpoint_host, - "endpoint_host_key": endpoint_host_key, - "endpoint_host_proxy": endpoint_host_proxy, - } - - response = self.post( - "/knowledge_base/update_kb_endpoint", - json=data, - ) - return self._get_response_value(response, as_json=True) def update_kb_docs( self, knowledge_base_name: str, @@ -621,7 +575,7 @@ class ApiRequest: knowledge_base_name: str, allow_empty_kb: bool = True, vs_type: str = DEFAULT_VS_TYPE, - embed_model: str = EMBEDDING_MODEL, + embed_model: str = DEFAULT_EMBEDDING_MODEL, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, @@ -650,7 +604,7 @@ class ApiRequest: def embed_texts( self, texts: List[str], - embed_model: str = EMBEDDING_MODEL, + embed_model: str = DEFAULT_EMBEDDING_MODEL, to_query: bool = False, ) -> List[List[float]]: '''
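
On the webui side the ApiRequest client mirrors the slimmer API: knowledge bases are created with only a vector store type and an embedding model name, and embedding vectors are fetched from the API server instead of a locally loaded model. A short sketch; the method names and defaults come from the hunks above, the concrete values are only examples:

    from webui_pages.utils import ApiRequest
    from server.utils import api_address

    api = ApiRequest(base_url=api_address())

    # endpoint_host / endpoint_host_key / endpoint_host_proxy are no longer accepted here.
    api.create_knowledge_base(
        knowledge_base_name="samples",
        vector_store_type="faiss",
        embed_model="bge-large-zh-v1.5",
    )

    # Embeddings are computed by the API server rather than in-process.
    vectors = api.embed_texts(["测试文本"], embed_model="bge-large-zh-v1.5", to_query=False)
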