Support lite mode: chat with LLMs and search engines through online APIs, without installing torch and other heavy dependencies (#1860)

* move get_default_llm_model from webui to ApiRequest

Add API endpoints and test cases for them:
- /server/get_prompt_template: fetch a prompt template configured on the server
- add a test case for multi-threaded knowledge-base access

Support lite mode: chat with LLMs and search engines through online APIs, without installing torch and other heavy dependencies

* fix bug in server.api

---------

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
liunux4odoo committed 2023-10-25 08:30:23 +08:00 (via GitHub)
parent be67ea43d8
commit 03e55e11c4
22 changed files with 541 additions and 227 deletions

requirements_lite.txt (new file)

@@ -0,0 +1,60 @@
langchain>=0.0.319
fschat>=0.2.31
openai
# sentence_transformers
# transformers>=4.33.0
# torch>=2.0.1
# torchvision
# torchaudio
fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
# unstructured[all-docs]>=0.10.4
# python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
# faiss-cpu
# accelerate
# spacy
# PyMuPDF==1.22.5
# rapidocr_onnxruntime>=1.3.2
requests
pathlib
pytest
# scikit-learn
# numexpr
# vllm==0.1.7; sys_platform == "linux"
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
# pgvector
numpy~=1.24.4
pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets
# tiktoken
einops
# scipy
# transformers_stream_generator==0.0.4
# search engine libs
duckduckgo-search
metaphor-python
strsimpy
markdownify
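Installing from this file (for example with `pip install -r requirements_lite.txt`) pulls in only langchain, fschat, the FastAPI/Streamlit stack, and the online-API SDKs. Every commented-out entry (torch, transformers, sentence_transformers, faiss-cpu, unstructured, and so on) is a heavy local-inference or vector-store dependency that lite mode deliberately leaves out; uncomment them only if you need the corresponding feature.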

server/agent/tools/search_internet.py

@@ -1,5 +1,5 @@
 import json
-from server.chat import search_engine_chat
+from server.chat.search_engine_chat import search_engine_chat
 from configs import VECTOR_SEARCH_TOP_K
 import asyncio
 from server.agent import model_container

server/api.py

@@ -9,19 +9,19 @@ from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
+from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (completion, chat, knowledge_base_chat, openai_chat,
-                         search_engine_chat, agent_chat)
-from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
-                                              update_docs, download_doc, recreate_vector_store,
-                                              search_docs, DocumentWithScore, update_info)
+from server.chat.chat import chat
+from server.chat.openai_chat import openai_chat
+from server.chat.search_engine_chat import search_engine_chat
+from server.chat.completion import completion
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
-from typing import List
+from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
+                          get_server_configs, get_prompt_template)
+from typing import List, Literal

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -30,7 +30,7 @@ async def document():
     return RedirectResponse(url="/docs")


-def create_app():
+def create_app(run_mode: str = None):
     app = FastAPI(
         title="Langchain-Chatchat API Server",
         version=VERSION
@@ -47,7 +47,13 @@ def create_app():
         allow_methods=["*"],
         allow_headers=["*"],
     )
+    mount_basic_routes(app)
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
+    return app
+
+def mount_basic_routes(app: FastAPI):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)
@@ -65,14 +71,69 @@ def create_app():
              tags=["Chat"],
              summary="与llm模型对话(通过LLMChain)")(chat)

-    app.post("/chat/knowledge_base_chat",
-             tags=["Chat"],
-             summary="与知识库对话")(knowledge_base_chat)
-
     app.post("/chat/search_engine_chat",
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)

+    # LLM模型相关接口
+    app.post("/llm_model/list_running_models",
+             tags=["LLM Model Management"],
+             summary="列出当前已加载的模型",
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)
+
+    app.post("/llm_model/get_model_config",
+             tags=["LLM Model Management"],
+             summary="获取模型配置(合并后)",
+             )(get_model_config)
+
+    app.post("/llm_model/stop",
+             tags=["LLM Model Management"],
+             summary="停止指定的LLM模型(Model Worker)",
+             )(stop_llm_model)
+
+    app.post("/llm_model/change",
+             tags=["LLM Model Management"],
+             summary="切换指定的LLM模型(Model Worker)",
+             )(change_llm_model)
+
+    # 服务器相关接口
+    app.post("/server/configs",
+             tags=["Server State"],
+             summary="获取服务器原始配置信息",
+             )(get_server_configs)
+
+    app.post("/server/list_search_engines",
+             tags=["Server State"],
+             summary="获取服务器支持的搜索引擎",
+             )(list_search_engines)
+
+    @app.post("/server/get_prompt_template",
+              tags=["Server State"],
+              summary="获取服务器配置的 prompt 模板")
+    def get_server_prompt_template(
+            type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"] = Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
+            name: str = Body("default", description="模板名称"),
+    ) -> str:
+        return get_prompt_template(type=type, name=name)
+
+def mount_knowledge_routes(app: FastAPI):
+    from server.chat.knowledge_base_chat import knowledge_base_chat
+    from server.chat.agent_chat import agent_chat
+    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
+                                                  update_docs, download_doc, recreate_vector_store,
+                                                  search_docs, DocumentWithScore, update_info)
+
+    app.post("/chat/knowledge_base_chat",
+             tags=["Chat"],
+             summary="与知识库对话")(knowledge_base_chat)
+
     app.post("/chat/agent_chat",
              tags=["Chat"],
              summary="与agent对话")(agent_chat)
@@ -139,48 +200,6 @@ def create_app():
              summary="根据content中文档重建向量库,流式输出处理进度。"
              )(recreate_vector_store)

-    # LLM模型相关接口
-    app.post("/llm_model/list_running_models",
-             tags=["LLM Model Management"],
-             summary="列出当前已加载的模型",
-             )(list_running_models)
-
-    app.post("/llm_model/list_config_models",
-             tags=["LLM Model Management"],
-             summary="列出configs已配置的模型",
-             )(list_config_models)
-
-    app.post("/llm_model/get_model_config",
-             tags=["LLM Model Management"],
-             summary="获取模型配置(合并后)",
-             )(get_model_config)
-
-    app.post("/llm_model/stop",
-             tags=["LLM Model Management"],
-             summary="停止指定的LLM模型(Model Worker)",
-             )(stop_llm_model)
-
-    app.post("/llm_model/change",
-             tags=["LLM Model Management"],
-             summary="切换指定的LLM模型(Model Worker)",
-             )(change_llm_model)
-
-    # 服务器相关接口
-    app.post("/server/configs",
-             tags=["Server State"],
-             summary="获取服务器原始配置信息",
-             )(get_server_configs)
-
-    app.post("/server/list_search_engines",
-             tags=["Server State"],
-             summary="获取服务器支持的搜索引擎",
-             )(list_search_engines)
-
-    return app
-
-
-app = create_app()

 def run_api(host, port, **kwargs):
     if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
@@ -205,6 +224,10 @@ if __name__ == "__main__":
     # 初始化消息
     args = parser.parse_args()
     args_dict = vars(args)
+
+    app = create_app()
+    mount_knowledge_routes(app)
+
     run_api(host=args.host,
             port=args.port,
             ssl_keyfile=args.ssl_keyfile,
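With these routes in place, the new template endpoint can be exercised directly. A minimal sketch in Python; the address assumes the project's default API server port (7861), which may differ in your deployment:

import requests

# Body fields mirror the get_server_prompt_template signature above.
resp = requests.post(
    "http://127.0.0.1:7861/server/get_prompt_template",
    json={"type": "search_engine_chat", "name": "default"},
)
print(resp.text)  # the endpoint returns the raw prompt template string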

server/chat/__init__.py (deleted)

@@ -1,6 +0,0 @@
-from .chat import chat
-from .completion import completion
-from .knowledge_base_chat import knowledge_base_chat
-from .openai_chat import openai_chat
-from .search_engine_chat import search_engine_chat
-from .agent_chat import agent_chat

server/chat/search_engine_chat.py

@@ -1,4 +1,5 @@
-from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
+from langchain.utilities.bing_search import BingSearchAPIWrapper
+from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)
@@ -12,13 +13,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from typing import List, Optional, Dict
 from server.chat.utils import History
 from langchain.docstore.document import Document
 import json
+from strsimpy.normalized_levenshtein import NormalizedLevenshtein
+from markdownify import markdownify

-def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",
@@ -28,7 +32,7 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     return search.results(text, result_len)

-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)
@@ -36,40 +40,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
 def metaphor_search(
         text: str,
         result_len: int = SEARCH_ENGINE_TOP_K,
-        splitter_name: str = "SpacyTextSplitter",
+        split_result: bool = False,
         chunk_size: int = 500,
         chunk_overlap: int = OVERLAP_SIZE,
 ) -> List[Dict]:
     from metaphor_python import Metaphor
-    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
-    from server.knowledge_base.utils import make_text_splitter

     if not METAPHOR_API_KEY:
         return []

     client = Metaphor(METAPHOR_API_KEY)
     search = client.search(text, num_results=result_len, use_autoprompt=True)
     contents = search.get_contents().contents
+    for x in contents:
+        x.extract = markdownify(x.extract)

-    # metaphor 返回的内容都是长文本,需要分词再检索
-    docs = [Document(page_content=x.extract,
-                     metadata={"link": x.url, "title": x.title})
-            for x in contents]
-    text_splitter = make_text_splitter(splitter_name=splitter_name,
-                                       chunk_size=chunk_size,
-                                       chunk_overlap=chunk_overlap)
-    splitted_docs = text_splitter.split_documents(docs)
-
-    # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
-    if len(splitted_docs) > result_len:
-        vs = memo_faiss_pool.new_vector_store()
-        vs.add_documents(splitted_docs)
-        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
-
-    docs = [{"snippet": x.page_content,
-             "link": x.metadata["link"],
-             "title": x.metadata["title"]}
-            for x in splitted_docs]
+    if split_result:
+        # metaphor 返回的内容都是长文本,需要分词再检索
+        docs = [Document(page_content=x.extract,
+                         metadata={"link": x.url, "title": x.title})
+                for x in contents]
+        text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
+                                                       chunk_size=chunk_size,
+                                                       chunk_overlap=chunk_overlap)
+        splitted_docs = text_splitter.split_documents(docs)
+
+        # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
+        if len(splitted_docs) > result_len:
+            normal = NormalizedLevenshtein()
+            for x in splitted_docs:
+                x.metadata["score"] = normal.similarity(text, x.page_content)
+            splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
+            splitted_docs = splitted_docs[:result_len]
+
+        docs = [{"snippet": x.page_content,
+                 "link": x.metadata["link"],
+                 "title": x.metadata["title"]}
+                for x in splitted_docs]
+    else:
+        docs = [{"snippet": x.extract,
+                 "link": x.url,
+                 "title": x.title}
+                for x in contents]

     return docs
@@ -93,9 +106,10 @@ async def lookup_search_engine(
         query: str,
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
+        split_result: bool = False,
 ):
     search_engine = SEARCH_ENGINES[search_engine_name]
-    results = await run_in_threadpool(search_engine, query, result_len=top_k)
+    results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
     docs = search_result2docs(results)
     return docs

@@ -116,6 +130,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                              prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -140,7 +155,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             callbacks=[callback],
         )

-        docs = await lookup_search_engine(query, search_engine_name, top_k)
+        docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
         context = "\n".join([doc.page_content for doc in docs])

         prompt_template = get_prompt_template("search_engine_chat", prompt_name)
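The split_result path above swaps the old FAISS-based re-ranking (which required torch and a local embedding model) for a pure string-similarity sort, so it also works in lite mode. A self-contained sketch of the idea, using made-up sample chunks:

from strsimpy.normalized_levenshtein import NormalizedLevenshtein

query = "how to ask high quality questions"
chunks = [
    "how to ask good questions",
    "completely unrelated text",
    "tips for asking high quality questions",
]
normal = NormalizedLevenshtein()
# similarity() returns a value in [0, 1]; 1.0 means identical strings.
ranked = sorted(chunks, key=lambda c: normal.similarity(query, c), reverse=True)
print(ranked[:2])  # the two chunks closest to the query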

server/model_workers/baichuan.py

@@ -4,6 +4,8 @@
 import json
 import time
 import hashlib
+
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from server.utils import get_model_worker_config, get_httpx_client
 from fastchat import conversation as conv

@@ -78,16 +80,6 @@ class BaiChuanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 32768)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = config.get("version", version)
         self.api_key = config.get("api_key")

@@ -127,6 +119,18 @@ class BaiChuanWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn
     from server.utils import MakeFastAPIOffline

server/model_workers/base.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH

@@ -61,6 +62,9 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         # print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        raise NotImplementedError
+
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
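All the worker diffs below follow the same pattern: the hard-coded self.conv built in __init__ moves into an overridable make_conv_template() hook on ApiModelWorker, presumably invoked by the base class when no explicit conv_template is supplied. A hedged sketch of what a new online-API worker would override (MyApiWorker is a hypothetical name, not part of this commit):

from fastchat import conversation as conv
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker


class MyApiWorker(ApiModelWorker):  # hypothetical example worker
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # Describe the prompt format this particular API expects.
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )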

server/model_workers/fangzhou.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv

@@ -74,15 +75,6 @@ class FangZhouWorker(ApiModelWorker):
         self.api_key = config.get("api_key")
         self.secret_key = config.get("secret_key")

-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         super().generate_stream_gate(params)

@@ -107,6 +99,16 @@ class FangZhouWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/model_workers/minimax.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -22,16 +23,6 @@ class MiniMaxWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["USER", "BOT"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def prompt_to_messages(self, prompt: str) -> List[Dict]:
         result = super().prompt_to_messages(prompt)
         messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]

@@ -86,6 +77,17 @@ class MiniMaxWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["USER", "BOT"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/model_workers/qianfan.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv

@@ -120,16 +121,6 @@ class QianFanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = version
         self.api_key = config.get("api_key")

@@ -162,6 +153,17 @@ class QianFanWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/model_workers/qwen.py

@@ -1,5 +1,7 @@
 import json
 import sys
+
+from fastchat.conversation import Conversation
 from configs import TEMPERATURE
 from http import HTTPStatus
 from typing import List, Literal, Dict

@@ -68,15 +70,6 @@ class QwenWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.api_key = config.get("api_key")
         self.version = version

@@ -108,6 +101,17 @@ class QwenWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/model_workers/xinghuo.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -38,16 +39,6 @@ class XingHuoWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 8192)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接

@@ -86,6 +77,17 @@ class XingHuoWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/model_workers/zhipu.py

@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -23,16 +24,6 @@ class ChatGLMWorker(ApiModelWorker):
         super().__init__(**kwargs)
         self.version = version

-        # 这里的是chatglm api的模板,其它API的conv_template需要定制
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
-            messages=[],
-            roles=["Human", "Assistant"],
-            sep="\n###",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 维护request_id
         import zhipuai

@@ -59,6 +50,17 @@ class ChatGLMWorker(ApiModelWorker):
         print("embedding")
         # print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # 这里的是chatglm api的模板,其它API的conv_template需要定制
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            messages=[],
+            roles=["Human", "Assistant"],
+            sep="\n###",
+            stop_str="###",
+        )
+

 if __name__ == "__main__":
     import uvicorn

server/utils.py

@@ -477,7 +477,7 @@ def fschat_controller_address() -> str:

 def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
-    if model := get_model_worker_config(model_name):
+    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
         host = model["host"]
         if host == "0.0.0.0":
             host = "127.0.0.1"

startup.py

@@ -423,13 +423,13 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     uvicorn.run(app, host=host, port=port)


-def run_api_server(started_event: mp.Event = None):
+def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     from server.api import create_app
     import uvicorn
     from server.utils import set_httpx_config
     set_httpx_config()

-    app = create_app()
+    app = create_app(run_mode=run_mode)
     _set_app_event(app, started_event)

     host = API_SERVER["host"]

@@ -438,21 +438,27 @@ def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     uvicorn.run(app, host=host, port=port)


-def run_webui(started_event: mp.Event = None):
+def run_webui(started_event: mp.Event = None, run_mode: str = None):
     from server.utils import set_httpx_config
     set_httpx_config()

     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    p = subprocess.Popen(["streamlit", "run", "webui.py",
-                          "--server.address", host,
-                          "--server.port", str(port),
-                          "--theme.base", "light",
-                          "--theme.primaryColor", "#165dff",
-                          "--theme.secondaryBackgroundColor", "#f5f5f5",
-                          "--theme.textColor", "#000000",
-                          ])
+    cmd = ["streamlit", "run", "webui.py",
+           "--server.address", host,
+           "--server.port", str(port),
+           "--theme.base", "light",
+           "--theme.primaryColor", "#165dff",
+           "--theme.secondaryBackgroundColor", "#f5f5f5",
+           "--theme.textColor", "#000000",
+           ]
+    if run_mode == "lite":
+        cmd += [
+            "--",
+            "lite",
+        ]
+    p = subprocess.Popen(cmd)
     started_event.set()
     p.wait()

@@ -535,6 +541,13 @@ def parse_args() -> argparse.ArgumentParser:
         help="减少fastchat服务log信息",
         dest="quiet",
     )
+    parser.add_argument(
+        "-i",
+        "--lite",
+        action="store_true",
+        help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
+        dest="lite",
+    )
     args = parser.parse_args()
     return args, parser

@@ -596,6 +609,7 @@ async def start_main_server():
     mp.set_start_method("spawn")
     manager = mp.Manager()
+    run_mode = None

     queue = manager.Queue()
     args, parser = parse_args()

@@ -621,6 +635,10 @@ async def start_main_server():
         args.api = False
         args.webui = False

+    if args.lite:
+        args.model_worker = False
+        run_mode = "lite"
+
     dump_server_info(args=args)

     if len(sys.argv) > 1:

@@ -698,7 +716,7 @@ async def start_main_server():
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            kwargs=dict(started_event=api_started),
+            kwargs=dict(started_event=api_started, run_mode=run_mode),
             daemon=True,
         )
         processes["api"] = process

@@ -708,7 +726,7 @@ async def start_main_server():
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            kwargs=dict(started_event=webui_started),
+            kwargs=dict(started_event=webui_started, run_mode=run_mode),
             daemon=True,
         )
         processes["webui"] = process
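Taken together: launching with the new flag, e.g. `python startup.py -a --lite` (where `-a`/`--all-webui` is the project's existing switch for starting the API server and WebUI together), now skips the model-worker processes entirely, passes run_mode="lite" to both the API server and the WebUI, and forwards the literal argument `lite` to webui.py through streamlit's `--` passthrough.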

tests/api/test_server_state_api.py (new file)

@@ -0,0 +1,47 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from webui_pages.utils import ApiRequest

import pytest
from pprint import pprint
from typing import List


api = ApiRequest()


def test_get_default_llm():
    llm = api.get_default_llm_model()
    print(llm)
    assert isinstance(llm, tuple)
    assert isinstance(llm[0], str) and isinstance(llm[1], bool)


def test_server_configs():
    configs = api.get_server_configs()
    pprint(configs, depth=2)
    assert isinstance(configs, dict)
    assert len(configs) > 0


def test_list_search_engines():
    engines = api.list_search_engines()
    pprint(engines)
    assert isinstance(engines, list)
    assert len(engines) > 0


@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
def test_get_prompt_template(type):
    print(f"prompt template for: {type}")
    template = api.get_prompt_template(type=type)
    print(template)
    assert isinstance(template, str)
    assert len(template) > 0

tests/api/test_stream_chat_api_thread.py (new file)

@@ -0,0 +1,81 @@
import requests
import json
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address

from pprint import pprint
from concurrent.futures import ThreadPoolExecutor, as_completed
import time


api_base_url = api_address()


def dump_input(d, title):
    print("\n")
    print("=" * 30 + title + " input " + "=" * 30)
    pprint(d)


def dump_output(r, title):
    print("\n")
    print("=" * 30 + title + " output" + "=" * 30)
    for line in r.iter_content(None, decode_unicode=True):
        print(line, end="", flush=True)


headers = {
    'accept': 'application/json',
    'Content-Type': 'application/json',
}


def knowledge_chat(api="/chat/knowledge_base_chat"):
    url = f"{api_base_url}{api}"
    data = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {
                "role": "user",
                "content": "你好"
            },
            {
                "role": "assistant",
                "content": "你好,我是 ChatGLM"
            }
        ],
        "stream": True
    }
    result = []
    response = requests.post(url, headers=headers, json=data, stream=True)
    for line in response.iter_content(None, decode_unicode=True):
        data = json.loads(line)
        result.append(data)
    return result


def test_thread():
    threads = []
    times = []
    pool = ThreadPoolExecutor()
    start = time.time()
    for i in range(10):
        t = pool.submit(knowledge_chat)
        threads.append(t)

    for r in as_completed(threads):
        end = time.time()
        times.append(end - start)
        print("\nResult:\n")
        pprint(r.result())

    print("\nTime used:\n")
    for x in times:
        print(f"{x}")

webui.py

@@ -1,8 +1,9 @@
 import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
-from webui_pages import *
+from webui_pages.dialogue.dialogue import dialogue_page, chat_box
 import os
+import sys
 from configs import VERSION
 from server.utils import api_address

@@ -10,6 +11,8 @@ from server.utils import api_address
 api = ApiRequest(base_url=api_address())

 if __name__ == "__main__":
+    is_lite = "lite" in sys.argv
+
     st.set_page_config(
         "Langchain-Chatchat WebUI",
         os.path.join("img", "chatchat_icon_blue_square_v2.png"),

@@ -26,11 +29,15 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
-        "知识库管理": {
-            "icon": "hdd-stack",
-            "func": knowledge_base_page,
-        },
     }
+    if not is_lite:
+        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
+        pages.update({
+            "知识库管理": {
+                "icon": "hdd-stack",
+                "func": knowledge_base_page,
+            },
+        })

     with st.sidebar:
         st.image(

@@ -57,4 +64,4 @@ if __name__ == "__main__":
         )

     if selected_page in pages:
-        pages[selected_page]["func"](api)
+        pages[selected_page]["func"](api=api, is_lite=is_lite)
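Two details worth noting here: knowledge_base_page is now imported lazily inside the `if not is_lite` branch, so a lite install never imports the knowledge-base code path and its heavy dependencies; and `is_lite = "lite" in sys.argv` works because streamlit forwards everything after `--` on its command line to the target script's own argv, matching the `cmd += ["--", "lite"]` added in startup.py.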

webui_pages/__init__.py (deleted)

@@ -1,3 +0,0 @@
-from .dialogue import dialogue_page, chat_box
-from .knowledge_base import knowledge_base_page
-from .model_config import model_config_page

webui_pages/dialogue/__init__.py (deleted)

@@ -1 +0,0 @@
-from .dialogue import dialogue_page, chat_box

webui_pages/dialogue/dialogue.py

@@ -3,10 +3,11 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                      DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, LANGCHAIN_LLM_MODEL)
 from typing import List, Dict
+
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",

@@ -33,27 +34,9 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
     return chat_box.filter_history(history_len=history_len, filter=filter)


-def get_default_llm_model(api: ApiRequest) -> (str, bool):
-    '''
-    从服务器上获取当前运行的LLM模型。如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
-    返回类型为(model_name, is_local_model)
-    '''
-    running_models = api.list_running_models()
-    if not running_models:
-        return "", False
-
-    if LLM_MODEL in running_models:
-        return LLM_MODEL, True
-
-    local_models = [k for k, v in running_models.items() if not v.get("online_api")]
-    if local_models:
-        return local_models[0], True
-
-    return list(running_models)[0], False
-
-
-def dialogue_page(api: ApiRequest):
+def dialogue_page(api: ApiRequest, is_lite: bool = False):
     if not chat_box.chat_inited:
-        default_model = get_default_llm_model(api)[0]
+        default_model = api.get_default_llm_model()[0]
         st.toast(
             f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
             f"当前运行的模型`{default_model}`, 您可以开始提问了."

@@ -70,13 +53,19 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             text = f"{text} 当前知识库: `{cur_kb}`。"
         st.toast(text)

+    if is_lite:
+        dialogue_modes = ["LLM 对话",
+                          "搜索引擎问答",
+                          ]
+    else:
+        dialogue_modes = ["LLM 对话",
+                          "知识库问答",
+                          "搜索引擎问答",
+                          "自定义Agent问答",
+                          ]
     dialogue_mode = st.selectbox("请选择对话模式:",
-                                 ["LLM 对话",
-                                  "知识库问答",
-                                  "搜索引擎问答",
-                                  "自定义Agent问答",
-                                  ],
-                                 index=3,
+                                 dialogue_modes,
+                                 index=0,
                                  on_change=on_mode_change,
                                  key="dialogue_mode",
                                  )

@@ -107,7 +96,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         for k, v in config_models.get("langchain", {}).items():  # 列出LANGCHAIN_LLM_MODEL支持的模型
             available_models.append(k)
         llm_models = running_models + available_models
-        index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
+        index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,

@@ -116,9 +105,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                  key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-                and not llm_model in config_models.get("online", {})
-                and not llm_model in config_models.get("langchain", {})
-                and llm_model not in running_models):
+                and not is_lite
+                and not llm_model in config_models.get("online", {})
+                and not llm_model in config_models.get("langchain", {})
+                and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
                 r = api.change_llm_model(prev_model, llm_model)

@@ -128,12 +118,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     st.success(msg)
                     st.session_state["prev_llm_model"] = llm_model

-        index_prompt = {
-            "LLM 对话": "llm_chat",
-            "自定义Agent问答": "agent_chat",
-            "搜索引擎问答": "search_engine_chat",
-            "知识库问答": "knowledge_base_chat",
-        }
+        if is_lite:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "搜索引擎问答": "search_engine_chat",
+            }
+        else:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "自定义Agent问答": "agent_chat",
+                "搜索引擎问答": "search_engine_chat",
+                "知识库问答": "knowledge_base_chat",
+            }
         prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
         prompt_template_name = prompt_templates_kb_list[0]
         if "prompt_template_select" not in st.session_state:

@@ -284,7 +280,8 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                                  history=history,
                                                  model=llm_model,
                                                  prompt_name=prompt_template_name,
-                                                 temperature=temperature):
+                                                 temperature=temperature,
+                                                 split_result=se_top_k > 1):
             if error_msg := check_error_msg(d):  # check whether error occured
                 st.error(error_msg)
             elif chunk := d.get("answer"):

webui_pages/utils.py

@@ -20,12 +20,11 @@ from configs import (
     logger, log_verbose,
 )
 import httpx
-from server.chat.openai_chat import OpenAiChatMsgIn
 import contextlib
 import json
 import os
 from io import BytesIO
-from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
+from server.utils import set_httpx_config, api_address, get_httpx_client
 from pprint import pprint

@@ -213,7 +212,7 @@ class ApiRequest:
             if log_verbose:
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
-            return {"code": 500, "msg": msg}
+            return {"code": 500, "msg": msg, "data": None}

         if value_func is None:
             value_func = (lambda r: r)

@@ -233,9 +232,26 @@ class ApiRequest:
         return value_func(response)

     # 服务器信息
-    def get_server_configs(self, **kwargs):
+    def get_server_configs(self, **kwargs) -> Dict:
         response = self.post("/server/configs", **kwargs)
-        return self._get_response_value(response, lambda r: r.json())
+        return self._get_response_value(response, as_json=True)
+
+    def list_search_engines(self, **kwargs) -> List:
+        response = self.post("/server/list_search_engines", **kwargs)
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
+
+    def get_prompt_template(
+            self,
+            type: str = "llm_chat",
+            name: str = "default",
+            **kwargs,
+    ) -> str:
+        data = {
+            "type": type,
+            "name": name,
+        }
+        response = self.post("/server/get_prompt_template", json=data, **kwargs)
+        return self._get_response_value(response, value_func=lambda r: r.text)

     # 对话相关操作

@@ -251,16 +267,14 @@ class ApiRequest:
         '''
         对应api.py/chat/fastchat接口
         '''
-        msg = OpenAiChatMsgIn(**{
-            "messages": messages,
-            "stream": stream,
-            "model": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            **kwargs,
-        })
-
-        data = msg.dict(exclude_unset=True, exclude_none=True)
+        data = {
+            "messages": messages,
+            "stream": stream,
+            "model": model,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }

         print(f"received input message:")
         pprint(data)

@@ -268,6 +282,7 @@ class ApiRequest:
             "/chat/fastchat",
             json=data,
             stream=True,
+            **kwargs,
         )
         return self._httpx_stream2generator(response)

@@ -380,6 +395,7 @@ class ApiRequest:
             temperature: float = TEMPERATURE,
             max_tokens: int = None,
             prompt_name: str = "default",
+            split_result: bool = False,
     ):
         '''
         对应api.py/chat/search_engine_chat接口

@@ -394,6 +410,7 @@ class ApiRequest:
             "temperature": temperature,
             "max_tokens": max_tokens,
             "prompt_name": prompt_name,
+            "split_result": split_result,
         }

         print(f"received input message:")

@@ -659,6 +676,43 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

+    def get_default_llm_model(self) -> (str, bool):
+        '''
+        从服务器上获取当前运行的LLM模型。如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
+        返回类型为(model_name, is_local_model)
+        '''
+        def ret_sync():
+            running_models = self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+
+            return list(running_models)[0], False
+
+        async def ret_async():
+            running_models = await self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+
+            return list(running_models)[0], False
+
+        if self._use_async:
+            return ret_async()
+        else:
+            return ret_sync()
+
     def list_config_models(self) -> Dict[str, List[str]]:
         '''
         获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}
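Finally, a hedged sketch of driving a lite-mode deployment end-to-end through ApiRequest. The streamed dicts carry an "answer" key, as consumed in dialogue.py; the engine name "duckduckgo" and the query/search_engine_name parameter names are assumptions based on the server-side signature in this diff:

from webui_pages.utils import ApiRequest

api = ApiRequest()  # defaults to the configured api_address()

# Search-engine chat needs no local model weights in lite mode.
for d in api.search_engine_chat("什么是 Langchain-Chatchat?",
                                search_engine_name="duckduckgo",
                                top_k=3,
                                split_result=False):
    print(d.get("answer", ""), end="", flush=True)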