Support lite mode: run without heavy dependencies such as torch; LLM chat and search-engine chat are served through online APIs (#1860)

* move get_default_llm_model from webui to ApiRequest

Add API endpoints and test cases:
- /server/get_prompt_template: fetch the prompt templates configured on the server
- add a test case for multi-threaded access to the knowledge base

Support lite mode: run without heavy dependencies such as torch; LLM chat and search-engine chat are served through online APIs

* fix bug in server.api

---------

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
Commit 03e55e11c4 (parent be67ea43d8), authored by liunux4odoo on 2023-10-25 08:30:23 +08:00, committed by GitHub
22 changed files with 541 additions and 227 deletions
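
A minimal smoke test for the new client helpers added in this commit (a sketch only; it assumes the API server is already running and the repo root is on sys.path, with the server address taken from server.utils.api_address()):

from webui_pages.utils import ApiRequest
from server.utils import api_address

api = ApiRequest(base_url=api_address())
model, is_local = api.get_default_llm_model()   # e.g. ("zhipu-api", False) when only online APIs are running
print(api.get_prompt_template(type="llm_chat", name="default"))
print(api.list_search_engines())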

requirements_lite.txt (new file, 60 lines)

@ -0,0 +1,60 @@
langchain>=0.0.319
fschat>=0.2.31
openai
# sentence_transformers
# transformers>=4.33.0
# torch>=2.0.1
# torchvision
# torchaudio
fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
# unstructured[all-docs]>=0.10.4
# python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
# faiss-cpu
# accelerate
# spacy
# PyMuPDF==1.22.5
# rapidocr_onnxruntime>=1.3.2
requests
pathlib
pytest
# scikit-learn
# numexpr
# vllm==0.1.7; sys_platform == "linux"
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
# pgvector
numpy~=1.24.4
pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets
# tiktoken
einops
# scipy
# transformers_stream_generator==0.0.4
# search engine libs
duckduckgo-search
metaphor-python
strsimpy
markdownify

View File

@ -1,5 +1,5 @@
import json
from server.chat import search_engine_chat
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K
import asyncio
from server.agent import model_container

View File

@ -9,19 +9,19 @@ from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (completion,chat, knowledge_base_chat, openai_chat,
search_engine_chat, agent_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore, update_info)
from server.chat.chat import chat
from server.chat.openai_chat import openai_chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
get_model_config, list_search_engines)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
from typing import List
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)
from typing import List, Literal
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -30,7 +30,7 @@ async def document():
return RedirectResponse(url="/docs")
def create_app():
def create_app(run_mode: str = None):
app = FastAPI(
title="Langchain-Chatchat API Server",
version=VERSION
@ -47,7 +47,13 @@ def create_app():
allow_methods=["*"],
allow_headers=["*"],
)
mount_basic_routes(app)
if run_mode != "lite":
mount_knowledge_routes(app)
return app
def mount_basic_routes(app: FastAPI):
app.get("/",
response_model=BaseResponse,
summary="swagger 文档")(document)
@ -65,14 +71,69 @@ def create_app():
tags=["Chat"],
summary="与llm模型对话(通过LLMChain)")(chat)
app.post("/chat/knowledge_base_chat",
tags=["Chat"],
summary="与知识库对话")(knowledge_base_chat)
app.post("/chat/search_engine_chat",
tags=["Chat"],
summary="与搜索引擎对话")(search_engine_chat)
# LLM模型相关接口
app.post("/llm_model/list_running_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_running_models)
app.post("/llm_model/list_config_models",
tags=["LLM Model Management"],
summary="列出configs已配置的模型",
)(list_config_models)
app.post("/llm_model/get_model_config",
tags=["LLM Model Management"],
summary="获取模型配置(合并后)",
)(get_model_config)
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)(stop_llm_model)
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)(change_llm_model)
# 服务器相关接口
app.post("/server/configs",
tags=["Server State"],
summary="获取服务器原始配置信息",
)(get_server_configs)
app.post("/server/list_search_engines",
tags=["Server State"],
summary="获取服务器支持的搜索引擎",
)(list_search_engines)
@app.post("/server/get_prompt_template",
tags=["Server State"],
summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
name: str = Body("default", description="模板名称"),
) -> str:
return get_prompt_template(type=type, name=name)
def mount_knowledge_routes(app: FastAPI):
from server.chat.knowledge_base_chat import knowledge_base_chat
from server.chat.agent_chat import agent_chat
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore, update_info)
app.post("/chat/knowledge_base_chat",
tags=["Chat"],
summary="与知识库对话")(knowledge_base_chat)
app.post("/chat/agent_chat",
tags=["Chat"],
summary="与agent对话")(agent_chat)
@ -139,48 +200,6 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# LLM模型相关接口
app.post("/llm_model/list_running_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_running_models)
app.post("/llm_model/list_config_models",
tags=["LLM Model Management"],
summary="列出configs已配置的模型",
)(list_config_models)
app.post("/llm_model/get_model_config",
tags=["LLM Model Management"],
summary="获取模型配置(合并后)",
)(get_model_config)
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)(stop_llm_model)
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)(change_llm_model)
# 服务器相关接口
app.post("/server/configs",
tags=["Server State"],
summary="获取服务器原始配置信息",
)(get_server_configs)
app.post("/server/list_search_engines",
tags=["Server State"],
summary="获取服务器支持的搜索引擎",
)(list_search_engines)
return app
app = create_app()
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
@ -205,6 +224,10 @@ if __name__ == "__main__":
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
app = create_app()
mount_knowledge_routes(app)
run_api(host=args.host,
port=args.port,
ssl_keyfile=args.ssl_keyfile,
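
The new /server/get_prompt_template endpoint returns the template text itself rather than a JSON envelope. A quick way to exercise it once the API server is up (a sketch; api_address() reads the host and port from the project configs):

import requests
from server.utils import api_address

r = requests.post(
    f"{api_address()}/server/get_prompt_template",
    json={"type": "search_engine_chat", "name": "default"},
)
print(r.text)   # the configured prompt template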

View File

@ -1,6 +0,0 @@
from .chat import chat
from .completion import completion
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .search_engine_chat import search_engine_chat
from .agent_chat import agent_chat

View File

@ -1,4 +1,5 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
@ -12,13 +13,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
import json
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env info is not found",
@ -28,7 +32,7 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
return search.results(text, result_len)
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
search = DuckDuckGoSearchAPIWrapper()
return search.results(text, result_len)
@ -36,40 +40,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
def metaphor_search(
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
splitter_name: str = "SpacyTextSplitter",
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
from metaphor_python import Metaphor
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from server.knowledge_base.utils import make_text_splitter
if not METAPHOR_API_KEY:
return []
client = Metaphor(METAPHOR_API_KEY)
search = client.search(text, num_results=result_len, use_autoprompt=True)
contents = search.get_contents().contents
for x in contents:
x.extract = markdownify(x.extract)
# metaphor 返回的内容都是长文本,需要分词再检索
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = make_text_splitter(splitter_name=splitter_name,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
if len(splitted_docs) > result_len:
vs = memo_faiss_pool.new_vector_store()
vs.add_documents(splitted_docs)
splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
if split_result:
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein()
for x in splitted_docs:
x.metadata["score"] = normal.similarity(text, x.page_content)
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
else:
docs = [{"snippet": x.extract,
"link": x.url,
"title": x.title}
for x in contents]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
return docs
@ -93,9 +106,10 @@ async def lookup_search_engine(
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
):
search_engine = SEARCH_ENGINES[search_engine_name]
results = await run_in_threadpool(search_engine, query, result_len=top_k)
results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
docs = search_result2docs(results)
return docs
@ -116,6 +130,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -140,7 +155,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
callbacks=[callback],
)
docs = await lookup_search_engine(query, search_engine_name, top_k)
docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
context = "\n".join([doc.page_content for doc in docs])
prompt_template = get_prompt_template("search_engine_chat", prompt_name)
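
A streaming request that exercises the new split_result flag (a sketch modeled on the existing API tests; it assumes METAPHOR_API_KEY is configured, since splitting and re-ranking mainly matters for metaphor's long-text results):

import requests
from server.utils import api_address

resp = requests.post(
    f"{api_address()}/chat/search_engine_chat",
    json={
        "query": "What is Langchain-Chatchat?",
        "search_engine_name": "metaphor",
        "top_k": 3,
        "stream": True,
        "split_result": True,   # new parameter added in this commit
    },
    stream=True,
)
for line in resp.iter_content(None, decode_unicode=True):
    print(line, end="", flush=True)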

View File

@ -4,6 +4,8 @@
import json
import time
import hashlib
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
@ -78,16 +80,6 @@ class BaiChuanWorker(ApiModelWorker):
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.version = config.get("version",version)
self.api_key = config.get("api_key")
@ -127,6 +119,18 @@ class BaiChuanWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
@ -61,6 +62,9 @@ class ApiModelWorker(BaseModelWorker):
print("embedding")
# print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
raise NotImplementedError
# help methods
def get_config(self):
from server.utils import get_model_worker_config
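
The refactor above turns the per-worker conversation template into a make_conv_template() hook on ApiModelWorker, which each subclass overrides (presumably the base class now builds self.conv from this hook instead of every __init__ doing so). A hypothetical subclass following the new pattern would look like this (sketch only, mirroring the workers shown below):

from fastchat import conversation as conv
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker

class MyApiWorker(ApiModelWorker):   # hypothetical worker, not part of this commit
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # return the conversation template instead of assigning self.conv in __init__
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )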

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
@ -74,15 +75,6 @@ class FangZhouWorker(ApiModelWorker):
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
super().generate_stream_gate(params)
@ -107,6 +99,16 @@ class FangZhouWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@ -22,16 +23,6 @@ class MiniMaxWorker(ApiModelWorker):
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
def prompt_to_messages(self, prompt: str) -> List[Dict]:
result = super().prompt_to_messages(prompt)
messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
@ -86,6 +77,17 @@ class MiniMaxWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
@ -120,16 +121,6 @@ class QianFanWorker(ApiModelWorker):
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.version = version
self.api_key = config.get("api_key")
@ -162,6 +153,17 @@ class QianFanWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -1,5 +1,7 @@
import json
import sys
from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
@ -68,15 +70,6 @@ class QwenWorker(ApiModelWorker):
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.api_key = config.get("api_key")
self.version = version
@ -108,6 +101,17 @@ class QwenWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@ -38,16 +39,6 @@ class XingHuoWorker(ApiModelWorker):
kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接
@ -86,6 +77,17 @@ class XingHuoWorker(ApiModelWorker):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@ -23,16 +24,6 @@ class ChatGLMWorker(ApiModelWorker):
super().__init__(**kwargs)
self.version = version
# 这里的是chatglm api的模板,其它API的conv_template需要定制
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["Human", "Assistant"],
sep="\n###",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 维护request_id
import zhipuai
@ -59,6 +50,17 @@ class ChatGLMWorker(ApiModelWorker):
print("embedding")
# print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# 这里的是chatglm api的模板,其它API的conv_template需要定制
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["Human", "Assistant"],
sep="\n###",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn

View File

@ -477,7 +477,7 @@ def fschat_controller_address() -> str:
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
if model := get_model_worker_config(model_name):
if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
host = "127.0.0.1"

View File

@ -423,13 +423,13 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
uvicorn.run(app, host=host, port=port)
def run_api_server(started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
from server.api import create_app
import uvicorn
from server.utils import set_httpx_config
set_httpx_config()
app = create_app()
app = create_app(run_mode=run_mode)
_set_app_event(app, started_event)
host = API_SERVER["host"]
@ -438,21 +438,27 @@ def run_api_server(started_event: mp.Event = None):
uvicorn.run(app, host=host, port=port)
def run_webui(started_event: mp.Event = None):
def run_webui(started_event: mp.Event = None, run_mode: str = None):
from server.utils import set_httpx_config
set_httpx_config()
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
])
cmd = ["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
if run_mode == "lite":
cmd += [
"--",
"lite",
]
p = subprocess.Popen(cmd)
started_event.set()
p.wait()
@ -535,6 +541,13 @@ def parse_args() -> argparse.ArgumentParser:
help="减少fastchat服务log信息",
dest="quiet",
)
parser.add_argument(
"-i",
"--lite",
action="store_true",
help="以Lite模式运行仅支持在线API的LLM对话、搜索引擎对话",
dest="lite",
)
args = parser.parse_args()
return args, parser
@ -596,6 +609,7 @@ async def start_main_server():
mp.set_start_method("spawn")
manager = mp.Manager()
run_mode = None
queue = manager.Queue()
args, parser = parse_args()
@ -621,6 +635,10 @@ async def start_main_server():
args.api = False
args.webui = False
if args.lite:
args.model_worker = False
run_mode = "lite"
dump_server_info(args=args)
if len(sys.argv) > 1:
@ -698,7 +716,7 @@ async def start_main_server():
process = Process(
target=run_api_server,
name=f"API Server",
kwargs=dict(started_event=api_started),
kwargs=dict(started_event=api_started, run_mode=run_mode),
daemon=True,
)
processes["api"] = process
@ -708,7 +726,7 @@ async def start_main_server():
process = Process(
target=run_webui,
name=f"WEBUI Server",
kwargs=dict(started_event=webui_started),
kwargs=dict(started_event=webui_started, run_mode=run_mode),
daemon=True,
)
processes["webui"] = process

View File

@ -0,0 +1,47 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from webui_pages.utils import ApiRequest
import pytest
from pprint import pprint
from typing import List
api = ApiRequest()
def test_get_default_llm():
llm = api.get_default_llm_model()
print(llm)
assert isinstance(llm, tuple)
assert isinstance(llm[0], str) and isinstance(llm[1], bool)
def test_server_configs():
configs = api.get_server_configs()
pprint(configs, depth=2)
assert isinstance(configs, dict)
assert len(configs) > 0
def test_list_search_engines():
engines = api.list_search_engines()
pprint(engines)
assert isinstance(engines, list)
assert len(engines) > 0
@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
def test_get_prompt_template(type):
print(f"prompt template for: {type}")
template = api.get_prompt_template(type=type)
print(template)
assert isinstance(template, str)
assert len(template) > 0

View File

@ -0,0 +1,81 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
api_base_url = api_address()
def dump_input(d, title):
print("\n")
print("=" * 30 + title + " input " + "="*30)
pprint(d)
def dump_output(r, title):
print("\n")
print("=" * 30 + title + " output" + "="*30)
for line in r.iter_content(None, decode_unicode=True):
print(line, end="", flush=True)
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
def knowledge_chat(api="/chat/knowledge_base_chat"):
url = f"{api_base_url}{api}"
data = {
"query": "如何提问以获得高质量答案",
"knowledge_base_name": "samples",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
result = []
response = requests.post(url, headers=headers, json=data, stream=True)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
result.append(data)
return result
def test_thread():
threads = []
times = []
pool = ThreadPoolExecutor()
start = time.time()
for i in range(10):
t = pool.submit(knowledge_chat)
threads.append(t)
for r in as_completed(threads):
end = time.time()
times.append(end - start)
print("\nResult:\n")
pprint(r.result())
print("\nTime used:\n")
for x in times:
print(f"{x}")

View File

@ -1,8 +1,9 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages import *
from webui_pages.dialogue.dialogue import dialogue_page, chat_box
import os
import sys
from configs import VERSION
from server.utils import api_address
@ -10,6 +11,8 @@ from server.utils import api_address
api = ApiRequest(base_url=api_address())
if __name__ == "__main__":
is_lite = "lite" in sys.argv
st.set_page_config(
"Langchain-Chatchat WebUI",
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
@ -26,11 +29,15 @@ if __name__ == "__main__":
"icon": "chat",
"func": dialogue_page,
},
"知识库管理": {
"icon": "hdd-stack",
"func": knowledge_base_page,
},
}
if not is_lite:
from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
pages.update({
"知识库管理": {
"icon": "hdd-stack",
"func": knowledge_base_page,
},
})
with st.sidebar:
st.image(
@ -57,4 +64,4 @@ if __name__ == "__main__":
)
if selected_page in pages:
pages[selected_page]["func"](api)
pages[selected_page]["func"](api=api, is_lite=is_lite)

View File

@ -1,3 +0,0 @@
from .dialogue import dialogue_page, chat_box
from .knowledge_base import knowledge_base_page
from .model_config import model_config_page

View File

@ -1 +0,0 @@
from .dialogue import dialogue_page, chat_box

View File

@ -3,10 +3,11 @@ from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@ -33,27 +34,9 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
return chat_box.filter_history(history_len=history_len, filter=filter)
def get_default_llm_model(api: ApiRequest) -> (str, bool):
'''
从服务器上获取当前运行的LLM模型。如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回。
返回类型为(model_name, is_local_model)
'''
running_models = api.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return list(running_models)[0], False
def dialogue_page(api: ApiRequest):
def dialogue_page(api: ApiRequest, is_lite: bool = False):
if not chat_box.chat_inited:
default_model = get_default_llm_model(api)[0]
default_model = api.get_default_llm_model()[0]
st.toast(
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
f"当前运行的模型`{default_model}`, 您可以开始提问了."
@ -70,13 +53,19 @@ def dialogue_page(api: ApiRequest):
text = f"{text} 当前知识库: `{cur_kb}`。"
st.toast(text)
if is_lite:
dialogue_modes = ["LLM 对话",
"搜索引擎问答",
]
else:
dialogue_modes = ["LLM 对话",
"知识库问答",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
["LLM 对话",
"知识库问答",
"搜索引擎问答",
"自定义Agent问答",
],
index=3,
dialogue_modes,
index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
@ -107,7 +96,7 @@ def dialogue_page(api: ApiRequest):
for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
available_models.append(k)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
@ -116,9 +105,10 @@ def dialogue_page(api: ApiRequest):
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
and not is_lite
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
@ -128,12 +118,18 @@ def dialogue_page(api: ApiRequest):
st.success(msg)
st.session_state["prev_llm_model"] = llm_model
index_prompt = {
"LLM 对话": "llm_chat",
"自定义Agent问答": "agent_chat",
"搜索引擎问答": "search_engine_chat",
"知识库问答": "knowledge_base_chat",
}
if is_lite:
index_prompt = {
"LLM 对话": "llm_chat",
"搜索引擎问答": "search_engine_chat",
}
else:
index_prompt = {
"LLM 对话": "llm_chat",
"自定义Agent问答": "agent_chat",
"搜索引擎问答": "search_engine_chat",
"知识库问答": "knowledge_base_chat",
}
prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
prompt_template_name = prompt_templates_kb_list[0]
if "prompt_template_select" not in st.session_state:
@ -284,7 +280,8 @@ def dialogue_page(api: ApiRequest):
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
temperature=temperature,
split_result=se_top_k>1):
if error_msg := check_error_msg(d): # check whether an error occurred
st.error(error_msg)
elif chunk := d.get("answer"):

View File

@ -20,12 +20,11 @@ from configs import (
logger, log_verbose,
)
import httpx
from server.chat.openai_chat import OpenAiChatMsgIn
import contextlib
import json
import os
from io import BytesIO
from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
@ -213,7 +212,7 @@ class ApiRequest:
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg}
return {"code": 500, "msg": msg, "data": None}
if value_func is None:
value_func = (lambda r: r)
@ -233,9 +232,26 @@ class ApiRequest:
return value_func(response)
# 服务器信息
def get_server_configs(self, **kwargs):
def get_server_configs(self, **kwargs) -> Dict:
response = self.post("/server/configs", **kwargs)
return self._get_response_value(response, lambda r: r.json())
return self._get_response_value(response, as_json=True)
def list_search_engines(self, **kwargs) -> List:
response = self.post("/server/list_search_engines", **kwargs)
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
def get_prompt_template(
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
) -> str:
data = {
"type": type,
"name": name,
}
response = self.post("/server/get_prompt_template", json=data, **kwargs)
return self._get_response_value(response, value_func=lambda r: r.text)
# 对话相关操作
@ -251,16 +267,14 @@ class ApiRequest:
'''
对应api.py/chat/fastchat接口
'''
msg = OpenAiChatMsgIn(**{
data = {
"messages": messages,
"stream": stream,
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs,
})
}
data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:")
pprint(data)
@ -268,6 +282,7 @@ class ApiRequest:
"/chat/fastchat",
json=data,
stream=True,
**kwargs,
)
return self._httpx_stream2generator(response)
@ -380,6 +395,7 @@ class ApiRequest:
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
对应api.py/chat/search_engine_chat接口
@ -394,6 +410,7 @@ class ApiRequest:
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
"split_result": split_result,
}
print(f"received input message:")
@ -659,6 +676,43 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
def get_default_llm_model(self) -> (str, bool):
'''
从服务器上获取当前运行的LLM模型。如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回。
返回类型为(model_name, is_local_model)
'''
def ret_sync():
running_models = self.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return list(running_models)[0], False
async def ret_async():
running_models = await self.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return list(running_models)[0], False
if self._use_async:
return ret_async()
else:
return ret_sync()
def list_config_models(self) -> Dict[str, List[str]]:
'''
获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}