diff --git a/requirements_lite.txt b/requirements_lite.txt
new file mode 100644
index 00000000..127caa14
--- /dev/null
+++ b/requirements_lite.txt
@@ -0,0 +1,60 @@
+langchain>=0.0.319
+fschat>=0.2.31
+openai
+# sentence_transformers
+# transformers>=4.33.0
+# torch>=2.0.1
+# torchvision
+# torchaudio
+fastapi>=0.103.1
+nltk~=3.8.1
+uvicorn~=0.23.1
+starlette~=0.27.0
+pydantic~=1.10.11
+# unstructured[all-docs]>=0.10.4
+# python-magic-bin; sys_platform == 'win32'
+SQLAlchemy==2.0.19
+# faiss-cpu
+# accelerate
+# spacy
+# PyMuPDF==1.22.5
+# rapidocr_onnxruntime>=1.3.2
+
+requests
+pathlib
+pytest
+# scikit-learn
+# numexpr
+# vllm==0.1.7; sys_platform == "linux"
+# online api libs
+zhipuai
+dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
+# uncomment libs if you want to use corresponding vector store
+# pymilvus==2.1.3 # requires milvus==2.1.3
+# psycopg2
+# pgvector
+
+numpy~=1.24.4
+pandas~=2.0.3
+streamlit>=1.26.0
+streamlit-option-menu>=0.3.6
+streamlit-antd-components>=0.1.11
+streamlit-chatbox>=1.1.9
+streamlit-aggrid>=0.3.4.post3
+httpx~=0.24.1
+watchdog
+tqdm
+websockets
+# tiktoken
+einops
+# scipy
+# transformers_stream_generator==0.0.4
+
+# search engine libs
+duckduckgo-search
+metaphor-python
+strsimpy
+markdownify
diff --git a/server/agent/tools/search_internet.py b/server/agent/tools/search_internet.py
index de57d31f..5266efc9 100644
--- a/server/agent/tools/search_internet.py
+++ b/server/agent/tools/search_internet.py
@@ -1,5 +1,5 @@
 import json
-from server.chat import search_engine_chat
+from server.chat.search_engine_chat import search_engine_chat
 from configs import VECTOR_SEARCH_TOP_K
 import asyncio
 from server.agent import model_container
diff --git a/server/api.py b/server/api.py
index 35f9f7cb..d47e786e 100644
--- a/server/api.py
+++ b/server/api.py
@@ -9,19 +9,19 @@ from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
+from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (completion,chat, knowledge_base_chat, openai_chat,
-                         search_engine_chat, agent_chat)
-from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
-                                              update_docs, download_doc, recreate_vector_store,
-                                              search_docs, DocumentWithScore, update_info)
+from server.chat.chat import chat
+from server.chat.openai_chat import openai_chat
+from server.chat.search_engine_chat import search_engine_chat
+from server.chat.completion import completion
 from server.llm_api import (list_running_models, list_config_models, change_llm_model,
                             stop_llm_model, get_model_config, list_search_engines)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
-from typing import List
+from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
+                          get_server_configs, get_prompt_template)
+from typing import List, Literal
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
@@ -30,7 +30,7 @@ async def document():
     return RedirectResponse(url="/docs")
 
 
-def create_app():
+def create_app(run_mode: str = None):
     app = FastAPI(
         title="Langchain-Chatchat API Server",
         version=VERSION
     )
@@ -47,7 +47,13 @@ def create_app():
         allow_methods=["*"],
         allow_headers=["*"],
     )
+    mount_basic_routes(app)
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
+    return app
+
+
+def mount_basic_routes(app: FastAPI):
 
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)
@@ -65,14 +71,69 @@ def create_app():
              tags=["Chat"],
              summary="与llm模型对话(通过LLMChain)")(chat)
 
-    app.post("/chat/knowledge_base_chat",
-             tags=["Chat"],
-             summary="与知识库对话")(knowledge_base_chat)
-
     app.post("/chat/search_engine_chat",
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)
 
+    # LLM模型相关接口
+    app.post("/llm_model/list_running_models",
+             tags=["LLM Model Management"],
+             summary="列出当前已加载的模型",
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)
+
+    app.post("/llm_model/get_model_config",
+             tags=["LLM Model Management"],
+             summary="获取模型配置(合并后)",
+             )(get_model_config)
+
+    app.post("/llm_model/stop",
+             tags=["LLM Model Management"],
+             summary="停止指定的LLM模型(Model Worker)",
+             )(stop_llm_model)
+
+    app.post("/llm_model/change",
+             tags=["LLM Model Management"],
+             summary="切换指定的LLM模型(Model Worker)",
+             )(change_llm_model)
+
+    # 服务器相关接口
+    app.post("/server/configs",
+             tags=["Server State"],
+             summary="获取服务器原始配置信息",
+             )(get_server_configs)
+
+    app.post("/server/list_search_engines",
+             tags=["Server State"],
+             summary="获取服务器支持的搜索引擎",
+             )(list_search_engines)
+
+    @app.post("/server/get_prompt_template",
+              tags=["Server State"],
+              summary="获取服务区配置的 prompt 模板")
+    def get_server_prompt_template(
+            type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
+            name: str = Body("default", description="模板名称"),
+    ) -> str:
+        return get_prompt_template(type=type, name=name)
+
+
+def mount_knowledge_routes(app: FastAPI):
+    from server.chat.knowledge_base_chat import knowledge_base_chat
+    from server.chat.agent_chat import agent_chat
+    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
+                                                  update_docs, download_doc, recreate_vector_store,
+                                                  search_docs, DocumentWithScore, update_info)
+
+    app.post("/chat/knowledge_base_chat",
+             tags=["Chat"],
+             summary="与知识库对话")(knowledge_base_chat)
+
     app.post("/chat/agent_chat",
              tags=["Chat"],
              summary="与agent对话")(agent_chat)
@@ -139,48 +200,6 @@ def create_app():
              summary="根据content中文档重建向量库,流式输出处理进度。"
              )(recreate_vector_store)
 
-    # LLM模型相关接口
-    app.post("/llm_model/list_running_models",
-             tags=["LLM Model Management"],
-             summary="列出当前已加载的模型",
-             )(list_running_models)
-
-    app.post("/llm_model/list_config_models",
-             tags=["LLM Model Management"],
-             summary="列出configs已配置的模型",
-             )(list_config_models)
-
-    app.post("/llm_model/get_model_config",
-             tags=["LLM Model Management"],
-             summary="获取模型配置(合并后)",
-             )(get_model_config)
-
-    app.post("/llm_model/stop",
-             tags=["LLM Model Management"],
-             summary="停止指定的LLM模型(Model Worker)",
-             )(stop_llm_model)
-
-    app.post("/llm_model/change",
-             tags=["LLM Model Management"],
-             summary="切换指定的LLM模型(Model Worker)",
-             )(change_llm_model)
-
-    # 服务器相关接口
-    app.post("/server/configs",
-             tags=["Server State"],
-             summary="获取服务器原始配置信息",
-             )(get_server_configs)
-
-    app.post("/server/list_search_engines",
-             tags=["Server State"],
-             summary="获取服务器支持的搜索引擎",
-             )(list_search_engines)
-
-    return app
-
-
-app = create_app()
-
 
 def run_api(host, port, **kwargs):
     if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
@@ -205,6 +224,10 @@ if __name__ == "__main__":
     # 初始化消息
     args = parser.parse_args()
     args_dict = vars(args)
+
+    app = create_app()
+    mount_knowledge_routes(app)
+
     run_api(host=args.host,
             port=args.port,
             ssl_keyfile=args.ssl_keyfile,
diff --git a/server/chat/__init__.py b/server/chat/__init__.py
index 294aeab7..e69de29b 100644
--- a/server/chat/__init__.py
+++ b/server/chat/__init__.py
@@ -1,6 +0,0 @@
-from .chat import chat
-from .completion import completion
-from .knowledge_base_chat import knowledge_base_chat
-from .openai_chat import openai_chat
-from .search_engine_chat import search_engine_chat
-from .agent_chat import agent_chat
\ No newline at end of file
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index e1ccaa48..5ae82985 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,4 +1,5 @@
-from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
+from langchain.utilities.bing_search import BingSearchAPIWrapper
+from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)
@@ -12,13 +13,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from typing import List, Optional, Dict
 from server.chat.utils import History
 from langchain.docstore.document import Document
 import json
+from strsimpy.normalized_levenshtein import NormalizedLevenshtein
+from markdownify import markdownify
 
 
-def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",
@@ -28,7 +32,7 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)
 
 
-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)
 
@@ -36,40 +40,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
 def metaphor_search(
         text: str,
         result_len: int = SEARCH_ENGINE_TOP_K,
-        splitter_name: str = "SpacyTextSplitter",
+        split_result: bool = False,
         chunk_size: int = 500,
         chunk_overlap: int = OVERLAP_SIZE,
 ) -> List[Dict]:
     from metaphor_python import Metaphor
-    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
-    from server.knowledge_base.utils import make_text_splitter
 
     if not METAPHOR_API_KEY:
         return []
-
+
     client = Metaphor(METAPHOR_API_KEY)
     search = client.search(text, num_results=result_len, use_autoprompt=True)
     contents = search.get_contents().contents
+    for x in contents:
+        x.extract = markdownify(x.extract)
 
     # metaphor 返回的内容都是长文本,需要分词再检索
-    docs = [Document(page_content=x.extract,
-                     metadata={"link": x.url, "title": x.title})
-            for x in contents]
-    text_splitter = make_text_splitter(splitter_name=splitter_name,
-                                       chunk_size=chunk_size,
-                                       chunk_overlap=chunk_overlap)
-    splitted_docs = text_splitter.split_documents(docs)
-
-    # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
-    if len(splitted_docs) > result_len:
-        vs = memo_faiss_pool.new_vector_store()
-        vs.add_documents(splitted_docs)
-        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
+    if split_result:
+        docs = [Document(page_content=x.extract,
+                         metadata={"link": x.url, "title": x.title})
+                for x in contents]
+        text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
+                                                       chunk_size=chunk_size,
+                                                       chunk_overlap=chunk_overlap)
+        splitted_docs = text_splitter.split_documents(docs)
+
+        # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
+        if len(splitted_docs) > result_len:
+            normal = NormalizedLevenshtein()
+            for x in splitted_docs:
+                x.metadata["score"] = normal.similarity(text, x.page_content)
+            splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
+            splitted_docs = splitted_docs[:result_len]
+
+        docs = [{"snippet": x.page_content,
+                 "link": x.metadata["link"],
+                 "title": x.metadata["title"]}
+                for x in splitted_docs]
+    else:
+        docs = [{"snippet": x.extract,
+                 "link": x.url,
+                 "title": x.title}
+                for x in contents]
 
-    docs = [{"snippet": x.page_content,
-             "link": x.metadata["link"],
-             "title": x.metadata["title"]}
-            for x in splitted_docs]
     return docs
 
 
@@ -93,9 +106,10 @@ async def lookup_search_engine(
         query: str,
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
+        split_result: bool = False,
 ):
     search_engine = SEARCH_ENGINES[search_engine_name]
-    results = await run_in_threadpool(search_engine, query, result_len=top_k)
+    results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
     docs = search_result2docs(results)
     return docs
 
@@ -116,6 +130,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                              prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
                              ):
    if search_engine_name not in SEARCH_ENGINES.keys():
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -140,7 +155,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             callbacks=[callback],
         )
 
-        docs = await lookup_search_engine(query, search_engine_name, top_k)
+        docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
         context = "\n".join([doc.page_content for doc in docs])
 
         prompt_template = get_prompt_template("search_engine_chat", prompt_name)
diff --git a/server/model_workers/baichuan.py b/server/model_workers/baichuan.py
index 1c6a6f1d..ffb741c5 100644
--- a/server/model_workers/baichuan.py
+++ b/server/model_workers/baichuan.py
@@ -4,6 +4,8 @@ import json
 import time
 import hashlib
+
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from server.utils import get_model_worker_config, get_httpx_client
 from fastchat import conversation as conv
@@ -78,16 +80,6 @@ class BaiChuanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 32768)
         super().__init__(**kwargs)
 
-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = config.get("version",version)
         self.api_key = config.get("api_key")
@@ -127,6 +119,18 @@ class BaiChuanWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+
 if __name__ == "__main__":
     import uvicorn
     from server.utils import MakeFastAPIOffline
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index 7319879c..d4632bf1 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
@@ -61,6 +62,9 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         # print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        raise NotImplementedError
+
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py
index 33a6b7da..8243b188 100644
--- a/server/model_workers/fangzhou.py
+++ b/server/model_workers/fangzhou.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv
@@ -74,15 +75,6 @@ class FangZhouWorker(ApiModelWorker):
         self.api_key = config.get("api_key")
         self.secret_key = config.get("secret_key")
 
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         super().generate_stream_gate(params)
@@ -107,6 +99,16 @@ class FangZhouWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index 9079ea44..54eaf07e 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
@@ -22,16 +23,6 @@ class MiniMaxWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)
 
-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["USER", "BOT"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def prompt_to_messages(self, prompt: str) -> List[Dict]:
         result = super().prompt_to_messages(prompt)
         messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
@@ -86,6 +77,17 @@ class MiniMaxWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["USER", "BOT"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 5eefd407..4273a815 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv
@@ -120,16 +121,6 @@ class QianFanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)
 
-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = version
         self.api_key = config.get("api_key")
@@ -162,6 +153,17 @@ class QianFanWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py
index 32d87574..ab51b8a1 100644
--- a/server/model_workers/qwen.py
+++ b/server/model_workers/qwen.py
@@ -1,5 +1,7 @@
 import json
 import sys
+
+from fastchat.conversation import Conversation
 from configs import TEMPERATURE
 from http import HTTPStatus
 from typing import List, Literal, Dict
@@ -68,15 +70,6 @@ class QwenWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)
 
-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
         config = self.get_config()
         self.api_key = config.get("api_key")
         self.version = version
@@ -108,6 +101,17 @@ class QwenWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index bc98a9cf..3091c949 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
@@ -38,16 +39,6 @@ class XingHuoWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 8192)
         super().__init__(**kwargs)
 
-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接
 
@@ -86,6 +77,17 @@ class XingHuoWorker(ApiModelWorker):
         print("embedding")
         print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 321f01f6..34e1396a 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
@@ -23,16 +24,6 @@ class ChatGLMWorker(ApiModelWorker):
         super().__init__(**kwargs)
         self.version = version
 
-        # 这里的是chatglm api的模板,其它API的conv_template需要定制
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
-            messages=[],
-            roles=["Human", "Assistant"],
-            sep="\n###",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 维护request_id
         import zhipuai
@@ -59,6 +50,17 @@ class ChatGLMWorker(ApiModelWorker):
         print("embedding")
         # print(params)
 
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # 这里的是chatglm api的模板,其它API的conv_template需要定制
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            messages=[],
+            roles=["Human", "Assistant"],
+            sep="\n###",
+            stop_str="###",
+        )
+
 
 if __name__ == "__main__":
     import uvicorn
diff --git a/server/utils.py b/server/utils.py
index 8e74a9c9..55a38725 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -477,7 +477,7 @@ def fschat_controller_address() -> str:
 
 
 def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
-    if model := get_model_worker_config(model_name):
+    if model := get_model_worker_config(model_name): # TODO: depends fastchat
         host = model["host"]
         if host == "0.0.0.0":
             host = "127.0.0.1"
diff --git a/startup.py b/startup.py
index a51c8917..710b811e 100644
--- a/startup.py
+++ b/startup.py
@@ -423,13 +423,13 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     uvicorn.run(app, host=host, port=port)
 
 
-def run_api_server(started_event: mp.Event = None):
+def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     from server.api import create_app
     import uvicorn
     from server.utils import set_httpx_config
     set_httpx_config()
 
-    app = create_app()
+    app = create_app(run_mode=run_mode)
     _set_app_event(app, started_event)
 
     host = API_SERVER["host"]
@@ -438,21 +438,27 @@ def run_api_server(started_event: mp.Event = None):
     uvicorn.run(app, host=host, port=port)
 
 
-def run_webui(started_event: mp.Event = None):
+def run_webui(started_event: mp.Event = None, run_mode: str = None):
     from server.utils import set_httpx_config
     set_httpx_config()
 
     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]
 
-    p = subprocess.Popen(["streamlit", "run", "webui.py",
-                          "--server.address", host,
-                          "--server.port", str(port),
-                          "--theme.base", "light",
-                          "--theme.primaryColor", "#165dff",
-                          "--theme.secondaryBackgroundColor", "#f5f5f5",
-                          "--theme.textColor", "#000000",
-                          ])
+    cmd = ["streamlit", "run", "webui.py",
+           "--server.address", host,
+           "--server.port", str(port),
+           "--theme.base", "light",
+           "--theme.primaryColor", "#165dff",
+           "--theme.secondaryBackgroundColor", "#f5f5f5",
+           "--theme.textColor", "#000000",
+           ]
+    if run_mode == "lite":
+        cmd += [
+            "--",
+            "lite",
+        ]
+    p = subprocess.Popen(cmd)
     started_event.set()
     p.wait()
@@ -535,6 +541,13 @@ def parse_args() -> argparse.ArgumentParser:
         help="减少fastchat服务log信息",
         dest="quiet",
     )
+    parser.add_argument(
+        "-i",
+        "--lite",
+        action="store_true",
+        help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
+        dest="lite",
+    )
     args = parser.parse_args()
     return args, parser
@@ -596,6 +609,7 @@ async def start_main_server():
     mp.set_start_method("spawn")
     manager = mp.Manager()
+    run_mode = None
 
     queue = manager.Queue()
     args, parser = parse_args()
@@ -621,6 +635,10 @@ async def start_main_server():
         args.api = False
         args.webui = False
 
+    if args.lite:
+        args.model_worker = False
+        run_mode = "lite"
+
     dump_server_info(args=args)
 
     if len(sys.argv) > 1:
@@ -698,7 +716,7 @@ async def start_main_server():
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            kwargs=dict(started_event=api_started),
+            kwargs=dict(started_event=api_started, run_mode=run_mode),
             daemon=True,
         )
         processes["api"] = process
@@ -708,7 +726,7 @@ async def start_main_server():
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            kwargs=dict(started_event=webui_started),
+            kwargs=dict(started_event=webui_started, run_mode=run_mode),
             daemon=True,
         )
         processes["webui"] = process
diff --git a/tests/api/test_server_state_api.py b/tests/api/test_server_state_api.py
new file mode 100644
index 00000000..59c0985f
--- /dev/null
+++ b/tests/api/test_server_state_api.py
@@ -0,0 +1,47 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from webui_pages.utils import ApiRequest
+
+import pytest
+from pprint import pprint
+from typing import List
+
+
+api = ApiRequest()
+
+
+def test_get_default_llm():
+    llm = api.get_default_llm_model()
+
+    print(llm)
+    assert isinstance(llm, tuple)
+    assert isinstance(llm[0], str) and isinstance(llm[1], bool)
+
+
+def test_server_configs():
+    configs = api.get_server_configs()
+    pprint(configs, depth=2)
+
+    assert isinstance(configs, dict)
+    assert len(configs) > 0
+
+
+def test_list_search_engines():
+    engines = api.list_search_engines()
+    pprint(engines)
+
+    assert isinstance(engines, list)
+    assert len(engines) > 0
+
+
+@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
+def test_get_prompt_template(type):
+    print(f"prompt template for: {type}")
+    template = api.get_prompt_template(type=type)
+
+    print(template)
+    assert isinstance(template, str)
+    assert len(template) > 0
diff --git a/tests/api/test_stream_chat_api_thread.py b/tests/api/test_stream_chat_api_thread.py
new file mode 100644
index 00000000..bdf74367
--- /dev/null
+++ b/tests/api/test_stream_chat_api_thread.py
@@ -0,0 +1,81 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from configs import BING_SUBSCRIPTION_KEY
+from server.utils import api_address
+
+from pprint import pprint
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import time
+
+
+api_base_url = api_address()
+
+
+def dump_input(d, title):
+    print("\n")
+    print("=" * 30 + title + " input " + "="*30)
+    pprint(d)
+
+
+def dump_output(r, title):
+    print("\n")
+    print("=" * 30 + title + " output" + "="*30)
+    for line in r.iter_content(None, decode_unicode=True):
+        print(line, end="", flush=True)
+
+
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json',
+}
+
+
+def knowledge_chat(api="/chat/knowledge_base_chat"):
+    url = f"{api_base_url}{api}"
+    data = {
+        "query": "如何提问以获得高质量答案",
+        "knowledge_base_name": "samples",
+        "history": [
+            {
+                "role": "user",
+                "content": "你好"
+            },
+            {
+                "role": "assistant",
+                "content": "你好,我是 ChatGLM"
+            }
+        ],
+        "stream": True
+    }
+    result = []
+    response = requests.post(url, headers=headers, json=data, stream=True)
+
+    for line in response.iter_content(None, decode_unicode=True):
+        data = json.loads(line)
+        result.append(data)
+
+    return result
+
+
+def test_thread():
+    threads = []
+    times = []
+    pool = ThreadPoolExecutor()
+    start = time.time()
+    for i in range(10):
+        t = pool.submit(knowledge_chat)
+        threads.append(t)
+
+    for r in as_completed(threads):
+        end = time.time()
+        times.append(end - start)
+        print("\nResult:\n")
+        pprint(r.result())
+
+    print("\nTime used:\n")
+    for x in times:
+        print(f"{x}")
diff --git a/webui.py b/webui.py
index 85d6cb40..fe339390 100644
--- a/webui.py
+++ b/webui.py
@@ -1,8 +1,9 @@
 import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
-from webui_pages import *
+from webui_pages.dialogue.dialogue import dialogue_page, chat_box
 import os
+import sys
 from configs import VERSION
 from server.utils import api_address
 
@@ -10,6 +11,8 @@ from server.utils import api_address
 api = ApiRequest(base_url=api_address())
 
 if __name__ == "__main__":
+    is_lite = "lite" in sys.argv
+
     st.set_page_config(
         "Langchain-Chatchat WebUI",
         os.path.join("img", "chatchat_icon_blue_square_v2.png"),
@@ -26,11 +29,15 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
-        "知识库管理": {
-            "icon": "hdd-stack",
-            "func": knowledge_base_page,
-        },
     }
+    if not is_lite:
+        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
+        pages.update({
+            "知识库管理": {
+                "icon": "hdd-stack",
+                "func": knowledge_base_page,
+            },
+        })
 
     with st.sidebar:
         st.image(
@@ -57,4 +64,4 @@ if __name__ == "__main__":
         )
 
     if selected_page in pages:
-        pages[selected_page]["func"](api)
+        pages[selected_page]["func"](api=api, is_lite=is_lite)
diff --git a/webui_pages/__init__.py b/webui_pages/__init__.py
index 064c36be..e69de29b 100644
--- a/webui_pages/__init__.py
+++ b/webui_pages/__init__.py
@@ -1,3 +0,0 @@
-from .dialogue import dialogue_page, chat_box
-from .knowledge_base import knowledge_base_page
-from .model_config import model_config_page
\ No newline at end of file
diff --git a/webui_pages/dialogue/__init__.py b/webui_pages/dialogue/__init__.py
index 3a95d43c..e69de29b 100644
--- a/webui_pages/dialogue/__init__.py
+++ b/webui_pages/dialogue/__init__.py
@@ -1 +0,0 @@
-from .dialogue import dialogue_page, chat_box
\ No newline at end of file
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 9267be21..1d4201d0 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -3,10 +3,11 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                      DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
 from typing import List, Dict
 
+
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -33,27 +34,9 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
     return chat_box.filter_history(history_len=history_len, filter=filter)
 
 
-def get_default_llm_model(api: ApiRequest) -> (str, bool):
-    '''
-    从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
-    返回类型为(model_name, is_local_model)
-    '''
-    running_models = api.list_running_models()
-    if not running_models:
-        return "", False
-
-    if LLM_MODEL in running_models:
-        return LLM_MODEL, True
-
-    local_models = [k for k, v in running_models.items() if not v.get("online_api")]
-    if local_models:
-        return local_models[0], True
-    return list(running_models)[0], False
-
-
-def dialogue_page(api: ApiRequest):
+def dialogue_page(api: ApiRequest, is_lite: bool = False):
     if not chat_box.chat_inited:
-        default_model = get_default_llm_model(api)[0]
+        default_model = api.get_default_llm_model()[0]
         st.toast(
             f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
             f"当前运行的模型`{default_model}`, 您可以开始提问了."
@@ -70,13 +53,19 @@ def dialogue_page(api: ApiRequest):
             text = f"{text} 当前知识库: `{cur_kb}`。"
             st.toast(text)
 
+        if is_lite:
+            dialogue_modes = ["LLM 对话",
+                              "搜索引擎问答",
+                              ]
+        else:
+            dialogue_modes = ["LLM 对话",
+                              "知识库问答",
+                              "搜索引擎问答",
+                              "自定义Agent问答",
+                              ]
         dialogue_mode = st.selectbox("请选择对话模式:",
-                                     ["LLM 对话",
-                                      "知识库问答",
-                                      "搜索引擎问答",
-                                      "自定义Agent问答",
-                                      ],
-                                     index=3,
+                                     dialogue_modes,
+                                     index=0,
                                      on_change=on_mode_change,
                                      key="dialogue_mode",
                                      )
@@ -107,7 +96,7 @@ def dialogue_page(api: ApiRequest):
         for k, v in config_models.get("langchain", {}).items():  # 列出LANGCHAIN_LLM_MODEL支持的模型
             available_models.append(k)
         llm_models = running_models + available_models
-        index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
+        index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,
@@ -116,9 +105,10 @@ def dialogue_page(api: ApiRequest):
                                  key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-                and not llm_model in config_models.get("online", {})
-                and not llm_model in config_models.get("langchain", {})
-                and llm_model not in running_models):
+                and not is_lite
+                and not llm_model in config_models.get("online", {})
+                and not llm_model in config_models.get("langchain", {})
+                and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
                 r = api.change_llm_model(prev_model, llm_model)
@@ -128,12 +118,18 @@ def dialogue_page(api: ApiRequest):
                 st.success(msg)
                 st.session_state["prev_llm_model"] = llm_model
 
-        index_prompt = {
-            "LLM 对话": "llm_chat",
-            "自定义Agent问答": "agent_chat",
-            "搜索引擎问答": "search_engine_chat",
-            "知识库问答": "knowledge_base_chat",
-        }
+        if is_lite:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "搜索引擎问答": "search_engine_chat",
+            }
+        else:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "自定义Agent问答": "agent_chat",
+                "搜索引擎问答": "search_engine_chat",
+                "知识库问答": "knowledge_base_chat",
+            }
         prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
         prompt_template_name = prompt_templates_kb_list[0]
         if "prompt_template_select" not in st.session_state:
@@ -284,7 +280,8 @@ def dialogue_page(api: ApiRequest):
                                                   history=history,
                                                   model=llm_model,
                                                   prompt_name=prompt_template_name,
-                                                  temperature=temperature):
+                                                  temperature=temperature,
+                                                  split_result=se_top_k>1):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 8190dba1..7a34f56e 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -20,12 +20,11 @@ from configs import (
     logger, log_verbose,
 )
 import httpx
-from server.chat.openai_chat import OpenAiChatMsgIn
 import contextlib
 import json
 import os
 from io import BytesIO
-from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
+from server.utils import set_httpx_config, api_address, get_httpx_client
 from pprint import pprint
 
 
@@ -213,7 +212,7 @@ class ApiRequest:
             if log_verbose:
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
-            return {"code": 500, "msg": msg}
+            return {"code": 500, "msg": msg, "data": None}
 
         if value_func is None:
             value_func = (lambda r: r)
@@ -233,9 +232,26 @@ class ApiRequest:
         return value_func(response)
 
     # 服务器信息
-    def get_server_configs(self, **kwargs):
+    def get_server_configs(self, **kwargs) -> Dict:
         response = self.post("/server/configs", **kwargs)
-        return self._get_response_value(response, lambda r: r.json())
+        return self._get_response_value(response, as_json=True)
+
+    def list_search_engines(self, **kwargs) -> List:
+        response = self.post("/server/list_search_engines", **kwargs)
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
+
+    def get_prompt_template(
+        self,
+        type: str = "llm_chat",
+        name: str = "default",
+        **kwargs,
+    ) -> str:
+        data = {
+            "type": type,
+            "name": name,
+        }
+        response = self.post("/server/get_prompt_template", json=data, **kwargs)
+        return self._get_response_value(response, value_func=lambda r: r.text)
 
     # 对话相关操作
 
@@ -251,16 +267,14 @@ class ApiRequest:
         '''
         对应api.py/chat/fastchat接口
         '''
-        msg = OpenAiChatMsgIn(**{
+        data = {
             "messages": messages,
             "stream": stream,
             "model": model,
             "temperature": temperature,
             "max_tokens": max_tokens,
-            **kwargs,
-        })
+        }
 
-        data = msg.dict(exclude_unset=True, exclude_none=True)
         print(f"received input message:")
         pprint(data)
 
@@ -268,6 +282,7 @@ class ApiRequest:
             "/chat/fastchat",
             json=data,
             stream=True,
+            **kwargs,
         )
         return self._httpx_stream2generator(response)
 
@@ -380,6 +395,7 @@ class ApiRequest:
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
+            split_result: bool = False,
     ):
         '''
         对应api.py/chat/search_engine_chat接口
@@ -394,6 +410,7 @@ class ApiRequest:
             "temperature": temperature,
             "max_tokens": max_tokens,
             "prompt_name": prompt_name,
+            "split_result": split_result,
         }
 
         print(f"received input message:")
@@ -659,6 +676,43 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
 
+
+    def get_default_llm_model(self) -> (str, bool):
+        '''
+        从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
+        返回类型为(model_name, is_local_model)
+        '''
+        def ret_sync():
+            running_models = self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+            return list(running_models)[0], False
+
+        async def ret_async():
+            running_models = await self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+            return list(running_models)[0], False
+
+        if self._use_async:
+            return ret_async()
+        else:
+            return ret_sync()
+
     def list_config_models(self) -> Dict[str, List[str]]:
         '''
         获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。