From b51ba11f4577a328d86fba47bd855e5447be8442 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Thu, 9 Nov 2023 22:15:52 +0800
Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=80=9A=E8=BF=87=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=A1=B9=E5=90=8C=E6=97=B6=E5=90=AF=E5=8A=A8=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=E6=A8=A1=E5=9E=8B=EF=BC=8C=E5=B0=86Wiki=E7=BA=B3?= =?UTF-8?q?=E5=85=A5samples=E7=9F=A5=E8=AF=86=E5=BA=93=20(#2002)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

New features:
- Replace the LLM_MODEL setting with an LLM_MODELS list so that several models can be started at the same time
- Add the project wiki to the samples knowledge base

Dependency changes:
- Pin streamlit~=1.27.0: 1.26.0 raises a rerun error, and 1.28.0 refreshes endlessly

Fixes and optimizations:
- Improve the get_default_llm_model logic
- Work around the 25-texts-per-request limit of the Qwen online API when generating embeddings (a short sketch of the batching pattern follows the patch)
- Skip files starting with "." when listing knowledge-base files on disk

--- .gitignore | 5 +- .gitmodules | 3 + chains/llmchain_with_history.py | 4 +- configs/basic_config.py.example | 1 + configs/kb_config.py.example | 3 + configs/model_config.py.example | 187 +++++++++--------- configs/server_config.py.example | 50 ++--- .../test_files}/langchain-ChatGLM_closed.csv | 0 .../langchain-ChatGLM_closed.jsonl | 0 .../test_files}/langchain-ChatGLM_closed.xlsx | Bin .../test_files}/langchain-ChatGLM_open.csv | 0 .../test_files}/langchain-ChatGLM_open.jsonl | 0 .../test_files}/langchain-ChatGLM_open.xlsx | Bin .../samples/content/{ => test_files}/test.txt | 0 knowledge_base/samples/content/wiki | 1 + requirements.txt | 2 +- requirements_lite.txt | 2 +- requirements_webui.txt | 2 +- server/chat/agent_chat.py | 6 +- server/chat/chat.py | 6 +- server/chat/completion.py | 6 +- server/chat/knowledge_base_chat.py | 6 +- server/chat/openai_chat.py | 4 +- server/chat/search_engine_chat.py | 6 +- .../kb_service/faiss_kb_service.py | 12 +- server/knowledge_base/utils.py | 7 +- server/llm_api.py | 8 +- server/model_workers/qwen.py | 26 ++- server/utils.py | 9 +- startup.py | 14 +- tests/agent/test_agent_function.py | 4 +- tests/api/test_llm_api.py | 1 - webui_pages/utils.py | 57 ++++-- 33 files changed, 238 insertions(+), 194 deletions(-) create mode 100644 .gitmodules rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_closed.csv (100%) rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_closed.jsonl (100%) rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_closed.xlsx (100%) rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_open.csv (100%) rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_open.jsonl (100%) rename knowledge_base/samples/{isssues_merge => content/test_files}/langchain-ChatGLM_open.xlsx (100%) rename knowledge_base/samples/content/{ => test_files}/test.txt (100%) create mode 160000 knowledge_base/samples/content/wiki diff --git a/.gitignore b/.gitignore index ced8d9e0..10953ca9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,10 @@ *.log.* *.bak logs -/knowledge_base/ +/knowledge_base/* +!/knowledge_base/samples +/knowledge_base/samples/vector_store + /configs/*.py .vscode/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..6d898201 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "knowledge_base/samples/content/wiki"] + path = knowledge_base/samples/content/wiki + url = https://github.com/chatchat-space/Langchain-Chatchat.wiki.git diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py index 9707c00c..2a845086 100644 --- a/chains/llmchain_with_history.py +++ 
b/chains/llmchain_with_history.py @@ -1,12 +1,12 @@ from server.utils import get_ChatOpenAI -from configs.model_config import LLM_MODEL, TEMPERATURE +from configs.model_config import LLM_MODELS, TEMPERATURE from langchain.chains import LLMChain from langchain.prompts.chat import ( ChatPromptTemplate, HumanMessagePromptTemplate, ) -model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE) +model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE) human_prompt = "{input}" diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example index 3540872c..ba9e046b 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -2,6 +2,7 @@ import logging import os import langchain + # 是否显示详细日志 log_verbose = False langchain.verbose = False diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 4e9cfe23..530433be 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -56,7 +56,10 @@ KB_INFO = { "知识库名称": "知识库介绍", "samples": "关于本项目issue的解答", } + + # 通常情况下不需要更改以下内容 + # 知识库默认存储路径 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") if not os.path.exists(KB_ROOT_PATH): diff --git a/configs/model_config.py.example b/configs/model_config.py.example index ff1986cc..e6d78b92 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,96 +1,13 @@ import os + # 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。 -# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录 +# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。 +# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。 MODEL_ROOT_PATH = "" -# 在以下字典中修改属性值,以指定本地embedding模型存储位置。支持3种设置方法: -# 1、将对应的值修改为模型绝对路径 -# 2、不修改此处的值(以 text2vec 为例): -# 2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录: -# - text2vec -# - GanymedeNil/text2vec-large-chinese -# - text2vec-large-chinese -# 2.2 如果以上本地路径不存在,则使用huggingface模型 -MODEL_PATH = { - "embed_model": { - "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", - "ernie-base": "nghuyong/ernie-3.0-base-zh", - "text2vec-base": "shibing624/text2vec-base-chinese", - "text2vec": "GanymedeNil/text2vec-large-chinese", - "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase", - "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence", - "text2vec-multilingual": "shibing624/text2vec-base-multilingual", - "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese", - "m3e-small": "moka-ai/m3e-small", - "m3e-base": "moka-ai/m3e-base", - "m3e-large": "moka-ai/m3e-large", - "bge-small-zh": "BAAI/bge-small-zh", - "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh", - "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", - "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5", - "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", - "piccolo-base-zh": "sensenova/piccolo-base-zh", - "piccolo-large-zh": "sensenova/piccolo-large-zh", - "text-embedding-ada-002": "your OPENAI_API_KEY", - }, - # TODO: add all supported llm models - "llm_model": { - # 以下部分模型并未完全测试,仅根据fastchat和vllm模型的模型列表推定支持 - "chatglm2-6b": "THUDM/chatglm2-6b", - "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", - "chatglm3-6b": "THUDM/chatglm3-6b-32k", - "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", - - "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", - "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat", - - "baichuan-7b": "baichuan-inc/Baichuan-7B", - "baichuan-13b": "baichuan-inc/Baichuan-13B", - 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat', - - "aquila-7b": "BAAI/Aquila-7B", - "aquilachat-7b": "BAAI/AquilaChat-7B", - - 
"internlm-7b": "internlm/internlm-7b", - "internlm-chat-7b": "internlm/internlm-chat-7b", - - "falcon-7b": "tiiuae/falcon-7b", - "falcon-40b": "tiiuae/falcon-40b", - "falcon-rw-7b": "tiiuae/falcon-rw-7b", - - "gpt2": "gpt2", - "gpt2-xl": "gpt2-xl", - - "gpt-j-6b": "EleutherAI/gpt-j-6b", - "gpt4all-j": "nomic-ai/gpt4all-j", - "gpt-neox-20b": "EleutherAI/gpt-neox-20b", - "pythia-12b": "EleutherAI/pythia-12b", - "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", - "dolly-v2-12b": "databricks/dolly-v2-12b", - "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", - - "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf", - "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf", - "open_llama_13b": "openlm-research/open_llama_13b", - "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3", - "koala": "young-geng/koala", - - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter", - "mpt-30b": "mosaicml/mpt-30b", - "opt-66b": "facebook/opt-66b", - "opt-iml-max-30b": "facebook/opt-iml-max-30b", - - "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-14B": "Qwen/Qwen-14B", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - }, -} # 选用的 Embedding 名称 -EMBEDDING_MODEL = "m3e-base" # 可以尝试最新的嵌入式sota模型:bge-large-zh-v1.5 +EMBEDDING_MODEL = "m3e-base" # bge-large-zh # Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 EMBEDDING_DEVICE = "auto" @@ -99,9 +16,11 @@ EMBEDDING_DEVICE = "auto" EMBEDDING_KEYWORD_FILE = "keywords.txt" EMBEDDING_MODEL_OUTPUT_PATH = "output" -# LLM 名称 -LLM_MODEL = "chatglm2-6b" -# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODEL) +# 要运行的 LLM 名称,可以包括本地模型和在线模型。 +# 第一个将作为 API 和 WEBUI 的默认模型 +LLM_MODELS = ["chatglm2-6b-int4", "zhipu-api", "openai-api"] + +# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) Agent_MODEL = None # LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 @@ -111,7 +30,6 @@ LLM_DEVICE = "auto" HISTORY_LEN = 3 # 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 - MAX_TOKENS = None # LLM通用对话参数 @@ -197,6 +115,93 @@ ONLINE_LLM_MODEL = { }, } +# 在以下字典中修改属性值,以指定本地embedding模型存储位置。支持3种设置方法: +# 1、将对应的值修改为模型绝对路径 +# 2、不修改此处的值(以 text2vec 为例): +# 2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录: +# - text2vec +# - GanymedeNil/text2vec-large-chinese +# - text2vec-large-chinese +# 2.2 如果以上本地路径不存在,则使用huggingface模型 +MODEL_PATH = { + "embed_model": { + "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", + "ernie-base": "nghuyong/ernie-3.0-base-zh", + "text2vec-base": "shibing624/text2vec-base-chinese", + "text2vec": "GanymedeNil/text2vec-large-chinese", + "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase", + "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence", + "text2vec-multilingual": "shibing624/text2vec-base-multilingual", + "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese", + "m3e-small": "moka-ai/m3e-small", + "m3e-base": "moka-ai/m3e-base", + "m3e-large": "moka-ai/m3e-large", + "bge-small-zh": "BAAI/bge-small-zh", + "bge-base-zh": "BAAI/bge-base-zh", + "bge-large-zh": "BAAI/bge-large-zh", + "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", + "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5", + "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", + "piccolo-base-zh": "sensenova/piccolo-base-zh", + "piccolo-large-zh": "sensenova/piccolo-large-zh", + "text-embedding-ada-002": "your OPENAI_API_KEY", + }, + + "llm_model": { + # 以下部分模型并未完全测试,仅根据fastchat和vllm模型的模型列表推定支持 + "chatglm2-6b": "THUDM/chatglm2-6b", + "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", + 
"chatglm3-6b": "THUDM/chatglm3-6b-32k", + "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", + + "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", + "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat", + + "baichuan-7b": "baichuan-inc/Baichuan-7B", + "baichuan-13b": "baichuan-inc/Baichuan-13B", + 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat', + + "aquila-7b": "BAAI/Aquila-7B", + "aquilachat-7b": "BAAI/AquilaChat-7B", + + "internlm-7b": "internlm/internlm-7b", + "internlm-chat-7b": "internlm/internlm-chat-7b", + + "falcon-7b": "tiiuae/falcon-7b", + "falcon-40b": "tiiuae/falcon-40b", + "falcon-rw-7b": "tiiuae/falcon-rw-7b", + + "gpt2": "gpt2", + "gpt2-xl": "gpt2-xl", + + "gpt-j-6b": "EleutherAI/gpt-j-6b", + "gpt4all-j": "nomic-ai/gpt4all-j", + "gpt-neox-20b": "EleutherAI/gpt-neox-20b", + "pythia-12b": "EleutherAI/pythia-12b", + "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", + "dolly-v2-12b": "databricks/dolly-v2-12b", + "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", + + "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf", + "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf", + "open_llama_13b": "openlm-research/open_llama_13b", + "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3", + "koala": "young-geng/koala", + + "mpt-7b": "mosaicml/mpt-7b", + "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter", + "mpt-30b": "mosaicml/mpt-30b", + "opt-66b": "facebook/opt-66b", + "opt-iml-max-30b": "facebook/opt-iml-max-30b", + + "Qwen-7B": "Qwen/Qwen-7B", + "Qwen-14B": "Qwen/Qwen-14B", + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", + }, +} + + # 通常情况下不需要更改以下内容 # nltk 模型存储路径 diff --git a/configs/server_config.py.example b/configs/server_config.py.example index e7bfb5ae..714a3432 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -31,8 +31,7 @@ FSCHAT_OPENAI_API = { # fastchat model_worker server # 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。 -# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL -# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里(LLM_MODEL会自动添加) +# 在启动startup.py时,可用通过`--model-name xxxx yyyy`指定模型,不指定则为LLM_MODELS FSCHAT_MODEL_WORKERS = { # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。 "default": { @@ -58,7 +57,7 @@ FSCHAT_MODEL_WORKERS = { # "awq_ckpt": None, # "awq_wbits": 16, # "awq_groupsize": -1, - # "model_names": [LLM_MODEL], + # "model_names": LLM_MODELS, # "conv_template": None, # "limit_worker_concurrency": 5, # "stream_interval": 2, @@ -96,30 +95,31 @@ FSCHAT_MODEL_WORKERS = { # "device": "cpu", # }, - "zhipu-api": { # 请为每个要运行的在线API设置不同的端口 + #以下配置可以不用修改,在model_config中设置启动的模型 + "zhipu-api": { "port": 21001, }, - # "minimax-api": { - # "port": 21002, - # }, - # "xinghuo-api": { - # "port": 21003, - # }, - # "qianfan-api": { - # "port": 21004, - # }, - # "fangzhou-api": { - # "port": 21005, - # }, - # "qwen-api": { - # "port": 21006, - # }, - # "baichuan-api": { - # "port": 21007, - # }, - # "azure-api": { - # "port": 21008, - # }, + "minimax-api": { + "port": 21002, + }, + "xinghuo-api": { + "port": 21003, + }, + "qianfan-api": { + "port": 21004, + }, + "fangzhou-api": { + "port": 21005, + }, + "qwen-api": { + "port": 21006, + }, + "baichuan-api": { + "port": 21007, + }, + "azure-api": { + "port": 21008, + }, } # fastchat multi model worker server diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.csv b/knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.csv similarity index 100% rename from 
knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.csv rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.csv diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.jsonl b/knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.jsonl similarity index 100% rename from knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.jsonl rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.jsonl diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.xlsx b/knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.xlsx similarity index 100% rename from knowledge_base/samples/isssues_merge/langchain-ChatGLM_closed.xlsx rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_closed.xlsx diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.csv b/knowledge_base/samples/content/test_files/langchain-ChatGLM_open.csv similarity index 100% rename from knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.csv rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_open.csv diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.jsonl b/knowledge_base/samples/content/test_files/langchain-ChatGLM_open.jsonl similarity index 100% rename from knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.jsonl rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_open.jsonl diff --git a/knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.xlsx b/knowledge_base/samples/content/test_files/langchain-ChatGLM_open.xlsx similarity index 100% rename from knowledge_base/samples/isssues_merge/langchain-ChatGLM_open.xlsx rename to knowledge_base/samples/content/test_files/langchain-ChatGLM_open.xlsx diff --git a/knowledge_base/samples/content/test.txt b/knowledge_base/samples/content/test_files/test.txt similarity index 100% rename from knowledge_base/samples/content/test.txt rename to knowledge_base/samples/content/test_files/test.txt diff --git a/knowledge_base/samples/content/wiki b/knowledge_base/samples/content/wiki new file mode 160000 index 00000000..b705cf80 --- /dev/null +++ b/knowledge_base/samples/content/wiki @@ -0,0 +1 @@ +Subproject commit b705cf80e4150cb900c77b343f0f9c62ec9a0278 diff --git a/requirements.txt b/requirements.txt index 4ae2df00..787c7ea8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,7 +53,7 @@ vllm>=0.2.0; sys_platform == "linux" # WebUI requirements -streamlit>=1.26.0 +streamlit~=1.27.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 streamlit-chatbox>=1.1.11 diff --git a/requirements_lite.txt b/requirements_lite.txt index 73ad8527..c3532f2e 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -41,7 +41,7 @@ dashscope>=1.10.0 # qwen numpy~=1.24.4 pandas~=2.0.3 -streamlit>=1.26.0 +streamlit~=1.27.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 streamlit-chatbox==1.1.11 diff --git a/requirements_webui.txt b/requirements_webui.txt index 608318a6..65ff7526 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -1,6 +1,6 @@ # WebUI requirements -streamlit>=1.26.0 +streamlit~=1.27.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 streamlit-chatbox>=1.1.11 diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 5c5983c0..ccbc1e19 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -5,7 +5,7 @@ from langchain.agents import AgentExecutor, LLMSingleActionAgent, 
initialize_age from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate from fastapi import Body from fastapi.responses import StreamingResponse -from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, Agent_MODEL +from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template from langchain.chains import LLMChain from typing import AsyncIterable, Optional, Dict @@ -26,7 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", @@ -38,7 +38,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples async def agent_chat_iterator( query: str, history: Optional[List[History]], - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = CustomAsyncIteratorCallbackHandler() diff --git a/server/chat/chat.py b/server/chat/chat.py index 8659af1b..b442664b 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,6 +1,6 @@ from fastapi import Body from fastapi.responses import StreamingResponse -from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY +from configs import LLM_MODELS, TEMPERATURE, SAVE_CHAT_HISTORY from server.utils import wrap_done, get_ChatOpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler @@ -22,7 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 {"role": "assistant", "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), @@ -32,7 +32,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 async def chat_iterator(query: str, history: List[History] = [], - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() diff --git a/server/chat/completion.py b/server/chat/completion.py index a24858f7..ee5e2d12 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -1,6 +1,6 @@ from fastapi import Body from fastapi.responses import StreamingResponse -from configs import LLM_MODEL, TEMPERATURE +from configs import LLM_MODELS, TEMPERATURE from server.utils import wrap_done, get_OpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler @@ -13,7 +13,7 @@ from server.utils import get_prompt_template async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), stream: bool = Body(False, description="流式输出"), echo: bool = Body(False, description="除了输出之外,还回显输入"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + model_name: str = Body(LLM_MODELS[0], 
description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), @@ -23,7 +23,7 @@ async def completion(query: str = Body(..., description="用户输入", examples #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 async def completion_iterator(query: str, - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, echo: bool = echo, ) -> AsyncIterable[str]: diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 664217bf..b607b1a5 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,6 +1,6 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse -from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) +from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) from server.utils import wrap_done, get_ChatOpenAI from server.utils import BaseResponse, get_prompt_template from langchain.chains import LLMChain @@ -30,7 +30,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), @@ -45,7 +45,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", async def knowledge_base_chat_iterator(query: str, top_k: int, history: Optional[List[History]], - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index 4d6c58d0..157880e4 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -1,7 +1,7 @@ from fastapi.responses import StreamingResponse from typing import List, Optional import openai -from configs import LLM_MODEL, logger, log_verbose +from configs import LLM_MODELS, logger, log_verbose from server.utils import get_model_worker_config, fschat_openai_api_address from pydantic import BaseModel @@ -12,7 +12,7 @@ class OpenAiMessage(BaseModel): class OpenAiChatMsgIn(BaseModel): - model: str = LLM_MODEL + model: str = LLM_MODELS[0] messages: List[OpenAiMessage] temperature: float = 0.7 n: int = 1 diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index fcda1fac..8325b4d9 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -1,7 +1,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, - LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE, + LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, TEXT_SPLITTER_NAME, OVERLAP_SIZE) from fastapi import Body from fastapi.responses import StreamingResponse @@ -126,7 +126,7 @@ async def search_engine_chat(query: str = 
Body(..., description="用户输入", "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), @@ -144,7 +144,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", search_engine_name: str, top_k: int, history: Optional[List[History]], - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index d1186209..07c57e05 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -48,7 +48,10 @@ class FaissKBService(KBService): def do_drop_kb(self): self.clear_vs() - shutil.rmtree(self.kb_path) + try: + shutil.rmtree(self.kb_path) + except Exception: + ... def do_search(self, query: str, @@ -90,8 +93,11 @@ class FaissKBService(KBService): def do_clear_vs(self): with kb_faiss_pool.atomic: kb_faiss_pool.pop((self.kb_name, self.vector_name)) - shutil.rmtree(self.vs_path) - os.makedirs(self.vs_path) + try: + shutil.rmtree(self.vs_path) + except Exception: + ... + os.makedirs(self.vs_path, exist_ok=True) def exist_doc(self, file_name: str): if super().exist_doc(file_name): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 688db55b..d6720ec1 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -7,7 +7,7 @@ from configs import ( logger, log_verbose, text_splitter_dict, - LLM_MODEL, + LLM_MODELS, TEXT_SPLITTER_NAME, ) import importlib @@ -57,7 +57,8 @@ def list_files_from_folder(kb_name: str): for root, _, files in os.walk(doc_path): tail = os.path.basename(root).lower() if (tail.startswith("temp") - or tail.startswith("tmp")): # 跳过 temp 或 tmp 开头的文件夹 + or tail.startswith("tmp") + or tail.startswith(".")): # 跳过 [temp, tmp, .] 
开头的文件夹 continue for file in files: if file.startswith("~$"): # 跳过 ~$ 开头的文件 @@ -192,7 +193,7 @@ def make_text_splitter( splitter_name: str = TEXT_SPLITTER_NAME, chunk_size: int = CHUNK_SIZE, chunk_overlap: int = OVERLAP_SIZE, - llm_model: str = LLM_MODEL, + llm_model: str = LLM_MODELS[0], ): """ 根据参数获取特定的分词器 diff --git a/server/llm_api.py b/server/llm_api.py index 0b78fc29..015a1c0a 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,5 +1,5 @@ from fastapi import Body -from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT +from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models, get_httpx_client, get_model_worker_config) from copy import deepcopy @@ -65,7 +65,7 @@ def get_model_config( def stop_llm_model( - model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), + model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]), controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ) -> BaseResponse: ''' @@ -89,8 +89,8 @@ def stop_llm_model( def change_llm_model( - model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), - new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), + model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]), + new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]), controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ): ''' diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index 2cb1ed0c..5c68791a 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -59,16 +59,22 @@ class QwenWorker(ApiModelWorker): import dashscope params.load_config(self.model_names[0]) - resp = dashscope.TextEmbedding.call( - model=params.embed_model or self.DEFAULT_EMBED_MODEL, - input=params.texts, # 最大25行 - api_key=params.api_key, - ) - if resp["status_code"] != 200: - return {"code": resp["status_code"], "msg": resp.message} - else: - embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] - return {"code": 200, "data": embeddings} + result = [] + i = 0 + while i < len(params.texts): + texts = params.texts[i:i+25] + resp = dashscope.TextEmbedding.call( + model=params.embed_model or self.DEFAULT_EMBED_MODEL, + input=texts, # 最大25行 + api_key=params.api_key, + ) + if resp["status_code"] != 200: + return {"code": resp["status_code"], "msg": resp.message} + else: + embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] + result += embeddings + i += 25 + return {"code": 200, "data": result} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/utils.py b/server/utils.py index 99fbac11..df4ee035 100644 --- a/server/utils.py +++ b/server/utils.py @@ -4,7 +4,7 @@ from typing import List from fastapi import FastAPI from pathlib import Path import asyncio -from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE, +from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE, MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose, FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT) import os @@ -345,8 +345,7 @@ def list_config_llm_models() -> Dict[str, Dict]: return [(model_name, config_type), ...] 
''' workers = list(FSCHAT_MODEL_WORKERS) - if LLM_MODEL not in workers: - workers.insert(0, LLM_MODEL) + return { "local": MODEL_PATH["llm_model"], "online": ONLINE_LLM_MODEL, @@ -431,7 +430,7 @@ def fschat_controller_address() -> str: return f"http://{host}:{port}" -def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str: +def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str: if model := get_model_worker_config(model_name): # TODO: depends fastchat host = model["host"] if host == "0.0.0.0": @@ -660,7 +659,7 @@ def get_server_configs() -> Dict: TEXT_SPLITTER_NAME, ) from configs.model_config import ( - LLM_MODEL, + LLM_MODELS, HISTORY_LEN, TEMPERATURE, ) diff --git a/startup.py b/startup.py index 0f0997db..88840aab 100644 --- a/startup.py +++ b/startup.py @@ -22,7 +22,7 @@ from configs import ( LOG_PATH, log_verbose, logger, - LLM_MODEL, + LLM_MODELS, EMBEDDING_MODEL, TEXT_SPLITTER_NAME, FSCHAT_CONTROLLER, @@ -359,7 +359,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None): def run_model_worker( - model_name: str = LLM_MODEL, + model_name: str = LLM_MODELS[0], controller_address: str = "", log_level: str = "INFO", q: mp.Queue = None, @@ -496,7 +496,7 @@ def parse_args() -> argparse.ArgumentParser: "--model-worker", action="store_true", help="run fastchat's model_worker server with specified model name. " - "specify --model-name if not using default LLM_MODEL", + "specify --model-name if not using default LLM_MODELS", dest="model_worker", ) parser.add_argument( @@ -504,7 +504,7 @@ def parse_args() -> argparse.ArgumentParser: "--model-name", type=str, nargs="+", - default=[LLM_MODEL], + default=LLM_MODELS, help="specify model name for model worker. " "add addition names with space seperated to start multiple model workers.", dest="model_name", @@ -568,7 +568,7 @@ def dump_server_info(after_start=False, args=None): print(f"langchain版本:{langchain.__version__}. 
fastchat版本:{fastchat.__version__}") print("\n") - models = [LLM_MODEL] + models = LLM_MODELS if args and args.model_name: models = args.model_name @@ -694,8 +694,8 @@ async def start_main_server(): processes["model_worker"][model_name] = process if args.api_worker: - configs = get_all_model_worker_configs() - for model_name, config in configs.items(): + for model_name in args.model_name: + config = get_model_worker_config(model_name) if (config.get("online_api") and config.get("worker_class") and model_name in FSCHAT_MODEL_WORKERS): diff --git a/tests/agent/test_agent_function.py b/tests/agent/test_agent_function.py index e860cb7a..27e5ae04 100644 --- a/tests/agent/test_agent_function.py +++ b/tests/agent/test_agent_function.py @@ -1,7 +1,7 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -from configs import LLM_MODEL, TEMPERATURE +from configs import LLM_MODELS, TEMPERATURE from server.utils import get_ChatOpenAI from langchain.chains import LLMChain from langchain.agents import LLMSingleActionAgent, AgentExecutor @@ -10,7 +10,7 @@ from langchain.memory import ConversationBufferWindowMemory memory = ConversationBufferWindowMemory(k=5) model = get_ChatOpenAI( - model_name=LLM_MODEL, + model_name=LLM_MODELS[0], temperature=TEMPERATURE, ) from server.agent.custom_template import CustomOutputParser, prompt diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index 9b9a4a66..a04c9776 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -6,7 +6,6 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from configs.server_config import FSCHAT_MODEL_WORKERS -from configs.model_config import LLM_MODEL from server.utils import api_address, get_model_worker_config from pprint import pprint diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 80b855bb..b4a520ef 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -8,7 +8,7 @@ from pathlib import Path from configs import ( EMBEDDING_MODEL, DEFAULT_VS_TYPE, - LLM_MODEL, + LLM_MODELS, TEMPERATURE, SCORE_THRESHOLD, CHUNK_SIZE, @@ -259,7 +259,7 @@ class ApiRequest: self, messages: List[Dict], stream: bool = True, - model: str = LLM_MODEL, + model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, **kwargs: Any, @@ -291,7 +291,7 @@ class ApiRequest: query: str, history: List[Dict] = [], stream: bool = True, - model: str = LLM_MODEL, + model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", @@ -321,7 +321,7 @@ class ApiRequest: query: str, history: List[Dict] = [], stream: bool = True, - model: str = LLM_MODEL, + model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", @@ -353,7 +353,7 @@ class ApiRequest: score_threshold: float = SCORE_THRESHOLD, history: List[Dict] = [], stream: bool = True, - model: str = LLM_MODEL, + model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", @@ -391,7 +391,7 @@ class ApiRequest: top_k: int = SEARCH_ENGINE_TOP_K, history: List[Dict] = [], stream: bool = True, - model: str = LLM_MODEL, + model: str = LLM_MODELS[0], temperature: float = TEMPERATURE, max_tokens: int = None, prompt_name: str = "default", @@ -677,9 +677,10 @@ class ApiRequest: return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", [])) - def 
get_default_llm_model(self) -> Tuple[str, bool]: + def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]: ''' - 从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回 + 从服务器上获取当前运行的LLM模型。 + 当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。 返回类型为(model_name, is_local_model) ''' def ret_sync(): @@ -687,26 +688,42 @@ class ApiRequest: if not running_models: return "", False - if LLM_MODEL in running_models: - return LLM_MODEL, True + model = "" + for m in LLM_MODELS: + if m not in running_models: + continue + is_local = not running_models[m].get("online_api") + if local_first and not is_local: + continue + else: + model = m + break - local_models = [k for k, v in running_models.items() if not v.get("online_api")] - if local_models: - return local_models[0], True - return list(running_models)[0], False + if not model: # LLM_MODELS中配置的模型都不在running_models里 + model = list(running_models)[0] + is_local = not running_models[model].get("online_api") + return model, is_local async def ret_async(): running_models = await self.list_running_models() if not running_models: return "", False - if LLM_MODEL in running_models: - return LLM_MODEL, True + model = "" + for m in LLM_MODELS: + if m not in running_models: + continue + is_local = not running_models[m].get("online_api") + if local_first and not is_local: + continue + else: + model = m + break - local_models = [k for k, v in running_models.items() if not v.get("online_api")] - if local_models: - return local_models[0], True - return list(running_models)[0], False + if not model: # LLM_MODELS中配置的模型都不在running_models里 + model = list(running_models)[0] + is_local = not running_models[model].get("online_api") + return model, is_local if self._use_async: return ret_async()
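
Note on the Qwen embedding change above: the DashScope text-embedding endpoint accepts at most 25 texts per request, so QwenWorker.do_embeddings now splits its input into chunks of 25 and concatenates the per-batch results. Below is a minimal, self-contained sketch of that batching pattern; batched_embed, MAX_BATCH_SIZE and the toy fake_embed are illustrative names only, not identifiers from the patch, which performs the same loop inline around dashscope.TextEmbedding.call.

    from typing import Callable, Dict, List

    MAX_BATCH_SIZE = 25  # per-request text limit of the Qwen/DashScope embedding API


    def batched_embed(
        texts: List[str],
        embed_fn: Callable[[List[str]], List[List[float]]],
        batch_size: int = MAX_BATCH_SIZE,
    ) -> Dict:
        """Embed texts in chunks of batch_size, mirroring the loop in QwenWorker.do_embeddings.

        embed_fn stands in for the real API call (dashscope.TextEmbedding.call in the patch);
        it receives one chunk and must return one embedding per input text.
        """
        result: List[List[float]] = []
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i + batch_size]  # never more than batch_size texts per call
            result.extend(embed_fn(chunk))
        return {"code": 200, "data": result}


    if __name__ == "__main__":
        # Toy embedder: a fixed-length vector per text, just enough to exercise the chunking.
        def fake_embed(chunk: List[str]) -> List[List[float]]:
            return [[float(len(t))] * 3 for t in chunk]

        out = batched_embed([f"text {n}" for n in range(60)], fake_embed)
        print(len(out["data"]))  # 60 embeddings, produced by three calls of 25, 25 and 10 texts

Unlike this sketch, the real worker also returns early with the API's status code and message as soon as any batch fails, as shown in the qwen.py hunk above.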