Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-01-19 13:23:16 +08:00

Support lite mode: run LLM chat and search-engine chat through online APIs, with no need to install heavy dependencies such as torch (#1860)

* Move get_default_llm_model from the webui to ApiRequest.
  Add API endpoints and their test cases:
  - /server/get_prompt_template: fetch a prompt template configured on the server
  - add a multi-threaded knowledge-base access test case
  Support lite mode: run LLM chat and search-engine chat through online APIs, with no need to install heavy dependencies such as torch.
* Fix a bug in server.api

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>

This commit is contained in:
parent be67ea43d8
commit 03e55e11c4
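
In lite mode the local model worker is not started and the API server mounts only the basic routes (LLM chat, search-engine chat, model management, server state); the knowledge-base routes are skipped. A minimal sketch of that code path, assuming the project's default API port 7861:

    # Build and serve the FastAPI app the way startup.py does when launched with the new --lite flag.
    # create_app(run_mode="lite") skips mount_knowledge_routes(), so only online-API chat endpoints are exposed.
    import uvicorn
    from server.api import create_app

    app = create_app(run_mode="lite")
    uvicorn.run(app, host="127.0.0.1", port=7861)  # host/port are assumptions (project defaults)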
requirements_lite.txt (new file, 60 lines)

@@ -0,0 +1,60 @@
+langchain>=0.0.319
+fschat>=0.2.31
+openai
+# sentence_transformers
+# transformers>=4.33.0
+# torch>=2.0.1
+# torchvision
+# torchaudio
+fastapi>=0.103.1
+nltk~=3.8.1
+uvicorn~=0.23.1
+starlette~=0.27.0
+pydantic~=1.10.11
+# unstructured[all-docs]>=0.10.4
+# python-magic-bin; sys_platform == 'win32'
+SQLAlchemy==2.0.19
+# faiss-cpu
+# accelerate
+# spacy
+# PyMuPDF==1.22.5
+# rapidocr_onnxruntime>=1.3.2
+
+requests
+pathlib
+pytest
+# scikit-learn
+# numexpr
+# vllm==0.1.7; sys_platform == "linux"
+# online api libs
+zhipuai
+dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
+
+# uncomment libs if you want to use corresponding vector store
+# pymilvus==2.1.3 # requires milvus==2.1.3
+# psycopg2
+# pgvector
+
+numpy~=1.24.4
+pandas~=2.0.3
+streamlit>=1.26.0
+streamlit-option-menu>=0.3.6
+streamlit-antd-components>=0.1.11
+streamlit-chatbox>=1.1.9
+streamlit-aggrid>=0.3.4.post3
+httpx~=0.24.1
+watchdog
+tqdm
+websockets
+# tiktoken
+einops
+# scipy
+# transformers_stream_generator==0.0.4
+
+# search engine libs
+duckduckgo-search
+metaphor-python
+strsimpy
+markdownify
@@ -1,5 +1,5 @@
 import json
-from server.chat import search_engine_chat
+from server.chat.search_engine_chat import search_engine_chat
 from configs import VECTOR_SEARCH_TOP_K
 import asyncio
 from server.agent import model_container
server/api.py (133 changed lines)

@@ -9,19 +9,19 @@ from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
+from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
-from server.chat import (completion,chat, knowledge_base_chat, openai_chat,
-                         search_engine_chat, agent_chat)
-from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
-                                              update_docs, download_doc, recreate_vector_store,
-                                              search_docs, DocumentWithScore, update_info)
+from server.chat.chat import chat
+from server.chat.openai_chat import openai_chat
+from server.chat.search_engine_chat import search_engine_chat
+from server.chat.completion import completion
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
-from typing import List
+from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
+                          get_server_configs, get_prompt_template)
+from typing import List, Literal

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -30,7 +30,7 @@ async def document():
     return RedirectResponse(url="/docs")


-def create_app():
+def create_app(run_mode: str = None):
     app = FastAPI(
         title="Langchain-Chatchat API Server",
         version=VERSION

@@ -47,7 +47,13 @@ def create_app():
         allow_methods=["*"],
         allow_headers=["*"],
     )
+    mount_basic_routes(app)
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
+    return app
+
+
+def mount_basic_routes(app: FastAPI):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)

@@ -65,14 +71,69 @@ def create_app():
              tags=["Chat"],
              summary="与llm模型对话(通过LLMChain)")(chat)

-    app.post("/chat/knowledge_base_chat",
-             tags=["Chat"],
-             summary="与知识库对话")(knowledge_base_chat)
-
     app.post("/chat/search_engine_chat",
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)

+    # LLM模型相关接口
+    app.post("/llm_model/list_running_models",
+             tags=["LLM Model Management"],
+             summary="列出当前已加载的模型",
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)
+
+    app.post("/llm_model/get_model_config",
+             tags=["LLM Model Management"],
+             summary="获取模型配置(合并后)",
+             )(get_model_config)
+
+    app.post("/llm_model/stop",
+             tags=["LLM Model Management"],
+             summary="停止指定的LLM模型(Model Worker)",
+             )(stop_llm_model)
+
+    app.post("/llm_model/change",
+             tags=["LLM Model Management"],
+             summary="切换指定的LLM模型(Model Worker)",
+             )(change_llm_model)
+
+    # 服务器相关接口
+    app.post("/server/configs",
+             tags=["Server State"],
+             summary="获取服务器原始配置信息",
+             )(get_server_configs)
+
+    app.post("/server/list_search_engines",
+             tags=["Server State"],
+             summary="获取服务器支持的搜索引擎",
+             )(list_search_engines)
+
+    @app.post("/server/get_prompt_template",
+              tags=["Server State"],
+              summary="获取服务区配置的 prompt 模板")
+    def get_server_prompt_template(
+        type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
+        name: str = Body("default", description="模板名称"),
+    ) -> str:
+        return get_prompt_template(type=type, name=name)
+
+
+def mount_knowledge_routes(app: FastAPI):
+    from server.chat.knowledge_base_chat import knowledge_base_chat
+    from server.chat.agent_chat import agent_chat
+    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
+                                                  update_docs, download_doc, recreate_vector_store,
+                                                  search_docs, DocumentWithScore, update_info)
+
+    app.post("/chat/knowledge_base_chat",
+             tags=["Chat"],
+             summary="与知识库对话")(knowledge_base_chat)
+
     app.post("/chat/agent_chat",
              tags=["Chat"],
              summary="与agent对话")(agent_chat)

@@ -139,48 +200,6 @@ def create_app():
                  summary="根据content中文档重建向量库,流式输出处理进度。"
                  )(recreate_vector_store)

-    # LLM模型相关接口
-    app.post("/llm_model/list_running_models",
-             tags=["LLM Model Management"],
-             summary="列出当前已加载的模型",
-             )(list_running_models)
-
-    app.post("/llm_model/list_config_models",
-             tags=["LLM Model Management"],
-             summary="列出configs已配置的模型",
-             )(list_config_models)
-
-    app.post("/llm_model/get_model_config",
-             tags=["LLM Model Management"],
-             summary="获取模型配置(合并后)",
-             )(get_model_config)
-
-    app.post("/llm_model/stop",
-             tags=["LLM Model Management"],
-             summary="停止指定的LLM模型(Model Worker)",
-             )(stop_llm_model)
-
-    app.post("/llm_model/change",
-             tags=["LLM Model Management"],
-             summary="切换指定的LLM模型(Model Worker)",
-             )(change_llm_model)
-
-    # 服务器相关接口
-    app.post("/server/configs",
-             tags=["Server State"],
-             summary="获取服务器原始配置信息",
-             )(get_server_configs)
-
-    app.post("/server/list_search_engines",
-             tags=["Server State"],
-             summary="获取服务器支持的搜索引擎",
-             )(list_search_engines)
-
-    return app
-
-
-app = create_app()
-
-
 def run_api(host, port, **kwargs):
     if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):

@@ -205,6 +224,10 @@ if __name__ == "__main__":
     # 初始化消息
     args = parser.parse_args()
     args_dict = vars(args)
+
+    app = create_app()
+    mount_knowledge_routes(app)
+
     run_api(host=args.host,
             port=args.port,
             ssl_keyfile=args.ssl_keyfile,
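
The new /server/get_prompt_template endpoint takes a template type and name in the JSON body and returns the configured template text. A hedged usage sketch with the requests library, assuming the API server listens on the project's default port 7861:

    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/server/get_prompt_template",
        json={"type": "llm_chat", "name": "default"},
    )
    print(resp.json())  # the prompt template string configured in configs/prompt_config.py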
@@ -1,6 +0,0 @@
-from .chat import chat
-from .completion import completion
-from .knowledge_base_chat import knowledge_base_chat
-from .openai_chat import openai_chat
-from .search_engine_chat import search_engine_chat
-from .agent_chat import agent_chat
@@ -1,4 +1,5 @@
-from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
+from langchain.utilities.bing_search import BingSearchAPIWrapper
+from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)

@@ -12,13 +13,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from typing import List, Optional, Dict
 from server.chat.utils import History
 from langchain.docstore.document import Document
 import json
+from strsimpy.normalized_levenshtein import NormalizedLevenshtein
+from markdownify import markdownify


-def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
         return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                  "title": "env info is not found",

@@ -28,7 +32,7 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
     return search.results(text, result_len)


-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
+def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
     search = DuckDuckGoSearchAPIWrapper()
     return search.results(text, result_len)

@@ -36,40 +40,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
 def metaphor_search(
         text: str,
         result_len: int = SEARCH_ENGINE_TOP_K,
-        splitter_name: str = "SpacyTextSplitter",
+        split_result: bool = False,
         chunk_size: int = 500,
         chunk_overlap: int = OVERLAP_SIZE,
 ) -> List[Dict]:
     from metaphor_python import Metaphor
-    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
-    from server.knowledge_base.utils import make_text_splitter

     if not METAPHOR_API_KEY:
         return []

     client = Metaphor(METAPHOR_API_KEY)
     search = client.search(text, num_results=result_len, use_autoprompt=True)
     contents = search.get_contents().contents
+    for x in contents:
+        x.extract = markdownify(x.extract)

     # metaphor 返回的内容都是长文本,需要分词再检索
-    docs = [Document(page_content=x.extract,
-                     metadata={"link": x.url, "title": x.title})
-            for x in contents]
-    text_splitter = make_text_splitter(splitter_name=splitter_name,
-                                       chunk_size=chunk_size,
-                                       chunk_overlap=chunk_overlap)
-    splitted_docs = text_splitter.split_documents(docs)
-
-    # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
-    if len(splitted_docs) > result_len:
-        vs = memo_faiss_pool.new_vector_store()
-        vs.add_documents(splitted_docs)
-        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
-
-    docs = [{"snippet": x.page_content,
-             "link": x.metadata["link"],
-             "title": x.metadata["title"]}
-            for x in splitted_docs]
+    if split_result:
+        docs = [Document(page_content=x.extract,
+                         metadata={"link": x.url, "title": x.title})
+                for x in contents]
+        text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
+                                                       chunk_size=chunk_size,
+                                                       chunk_overlap=chunk_overlap)
+        splitted_docs = text_splitter.split_documents(docs)
+
+        # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
+        if len(splitted_docs) > result_len:
+            normal = NormalizedLevenshtein()
+            for x in splitted_docs:
+                x.metadata["score"] = normal.similarity(text, x.page_content)
+            splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
+            splitted_docs = splitted_docs[:result_len]
+
+        docs = [{"snippet": x.page_content,
+                 "link": x.metadata["link"],
+                 "title": x.metadata["title"]}
+                for x in splitted_docs]
+    else:
+        docs = [{"snippet": x.extract,
+                 "link": x.url,
+                 "title": x.title}
+                for x in contents]
+
     return docs

@@ -93,9 +106,10 @@ async def lookup_search_engine(
         query: str,
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
+        split_result: bool = False,
 ):
     search_engine = SEARCH_ENGINES[search_engine_name]
-    results = await run_in_threadpool(search_engine, query, result_len=top_k)
+    results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
     docs = search_result2docs(results)
     return docs

@@ -116,6 +130,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                              prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -140,7 +155,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             callbacks=[callback],
         )

-        docs = await lookup_search_engine(query, search_engine_name, top_k)
+        docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
         context = "\n".join([doc.page_content for doc in docs])

         prompt_template = get_prompt_template("search_engine_chat", prompt_name)
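
With split_result enabled, the Metaphor results are converted to markdown, split with RecursiveCharacterTextSplitter, and re-ranked by normalized Levenshtein similarity to the query instead of going through a temporary FAISS store. A minimal sketch of that ranking step in isolation (the sample query and chunks are made up for illustration):

    from strsimpy.normalized_levenshtein import NormalizedLevenshtein

    def rank_chunks(query, chunks, top_k=3):
        # Score each chunk by similarity to the query (1.0 means identical) and keep the best ones.
        normal = NormalizedLevenshtein()
        return sorted(chunks, key=lambda c: normal.similarity(query, c), reverse=True)[:top_k]

    chunks = ["how to ask good questions", "streamlit is a frontend framework", "long passages returned by a search engine"]
    print(rank_chunks("how to ask questions", chunks, top_k=1))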
@@ -4,6 +4,8 @@
 import json
 import time
 import hashlib
+
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from server.utils import get_model_worker_config, get_httpx_client
 from fastchat import conversation as conv

@@ -78,16 +80,6 @@ class BaiChuanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 32768)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = config.get("version",version)
         self.api_key = config.get("api_key")

@@ -127,6 +119,18 @@ class BaiChuanWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+
 if __name__ == "__main__":
     import uvicorn
     from server.utils import MakeFastAPIOffline
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH

@@ -61,6 +62,9 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         # print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        raise NotImplementedError
+
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
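
With this change, conversation-template construction moves out of each worker's __init__ and into an overridable make_conv_template() hook on ApiModelWorker, as the per-worker diffs below show. A hedged sketch of the new pattern for a hypothetical worker (the class name and template values are illustrative, not part of the commit):

    from fastchat import conversation as conv
    from fastchat.conversation import Conversation
    from server.model_workers.base import ApiModelWorker

    class ExampleApiWorker(ApiModelWorker):
        def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
            # Each online-API worker now only declares its own conversation format here.
            return conv.Conversation(
                name=self.model_names[0],
                system_message="",
                messages=[],
                roles=["user", "assistant"],
                sep="\n### ",
                stop_str="###",
            )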
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv

@@ -74,15 +75,6 @@ class FangZhouWorker(ApiModelWorker):
         self.api_key = config.get("api_key")
         self.secret_key = config.get("secret_key")

-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         super().generate_stream_gate(params)

@@ -107,6 +99,16 @@ class FangZhouWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -22,16 +23,6 @@ class MiniMaxWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["USER", "BOT"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def prompt_to_messages(self, prompt: str) -> List[Dict]:
         result = super().prompt_to_messages(prompt)
         messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]

@@ -86,6 +77,17 @@ class MiniMaxWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["USER", "BOT"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from configs.model_config import TEMPERATURE
 from fastchat import conversation as conv

@@ -120,16 +121,6 @@ class QianFanWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
         config = self.get_config()
         self.version = version
         self.api_key = config.get("api_key")

@@ -162,6 +153,17 @@ class QianFanWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -1,5 +1,7 @@
 import json
 import sys

+from fastchat.conversation import Conversation
 from configs import TEMPERATURE
 from http import HTTPStatus
 from typing import List, Literal, Dict

@@ -68,15 +70,6 @@ class QwenWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 16384)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
-            messages=[],
-            roles=["user", "assistant", "system"],
-            sep="\n### ",
-            stop_str="###",
-        )
         config = self.get_config()
         self.api_key = config.get("api_key")
         self.version = version

@@ -108,6 +101,17 @@ class QwenWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            messages=[],
+            roles=["user", "assistant", "system"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -38,16 +39,6 @@ class XingHuoWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 8192)
         super().__init__(**kwargs)

-        # TODO: 确认模板是否需要修改
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="",
-            messages=[],
-            roles=["user", "assistant"],
-            sep="\n### ",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接

@@ -86,6 +77,17 @@ class XingHuoWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # TODO: 确认模板是否需要修改
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -1,3 +1,4 @@
+from fastchat.conversation import Conversation
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys

@@ -23,16 +24,6 @@ class ChatGLMWorker(ApiModelWorker):
         super().__init__(**kwargs)
         self.version = version

-        # 这里的是chatglm api的模板,其它API的conv_template需要定制
-        self.conv = conv.Conversation(
-            name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
-            messages=[],
-            roles=["Human", "Assistant"],
-            sep="\n###",
-            stop_str="###",
-        )
-
     def generate_stream_gate(self, params):
         # TODO: 维护request_id
         import zhipuai

@@ -59,6 +50,17 @@ class ChatGLMWorker(ApiModelWorker):
         print("embedding")
         # print(params)

+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        # 这里的是chatglm api的模板,其它API的conv_template需要定制
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            messages=[],
+            roles=["Human", "Assistant"],
+            sep="\n###",
+            stop_str="###",
+        )
+
 if __name__ == "__main__":
     import uvicorn
@@ -477,7 +477,7 @@ def fschat_controller_address() -> str:


 def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
-    if model := get_model_worker_config(model_name):
+    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
         host = model["host"]
         if host == "0.0.0.0":
             host = "127.0.0.1"
startup.py (44 changed lines)

@@ -423,13 +423,13 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
     uvicorn.run(app, host=host, port=port)


-def run_api_server(started_event: mp.Event = None):
+def run_api_server(started_event: mp.Event = None, run_mode: str = None):
     from server.api import create_app
     import uvicorn
     from server.utils import set_httpx_config
     set_httpx_config()

-    app = create_app()
+    app = create_app(run_mode=run_mode)
     _set_app_event(app, started_event)

     host = API_SERVER["host"]

@@ -438,21 +438,27 @@ def run_api_server(started_event: mp.Event = None):
     uvicorn.run(app, host=host, port=port)


-def run_webui(started_event: mp.Event = None):
+def run_webui(started_event: mp.Event = None, run_mode: str = None):
     from server.utils import set_httpx_config
     set_httpx_config()

     host = WEBUI_SERVER["host"]
     port = WEBUI_SERVER["port"]

-    p = subprocess.Popen(["streamlit", "run", "webui.py",
+    cmd = ["streamlit", "run", "webui.py",
            "--server.address", host,
            "--server.port", str(port),
            "--theme.base", "light",
            "--theme.primaryColor", "#165dff",
            "--theme.secondaryBackgroundColor", "#f5f5f5",
            "--theme.textColor", "#000000",
-                          ])
+           ]
+    if run_mode == "lite":
+        cmd += [
+            "--",
+            "lite",
+        ]
+    p = subprocess.Popen(cmd)
     started_event.set()
     p.wait()

@@ -535,6 +541,13 @@ def parse_args() -> argparse.ArgumentParser:
         help="减少fastchat服务log信息",
         dest="quiet",
     )
+    parser.add_argument(
+        "-i",
+        "--lite",
+        action="store_true",
+        help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
+        dest="lite",
+    )
     args = parser.parse_args()
     return args, parser

@@ -596,6 +609,7 @@ async def start_main_server():
     mp.set_start_method("spawn")
     manager = mp.Manager()
+    run_mode = None

     queue = manager.Queue()
     args, parser = parse_args()

@@ -621,6 +635,10 @@ async def start_main_server():
         args.api = False
         args.webui = False

+    if args.lite:
+        args.model_worker = False
+        run_mode = "lite"
+
     dump_server_info(args=args)

     if len(sys.argv) > 1:

@@ -698,7 +716,7 @@ async def start_main_server():
         process = Process(
             target=run_api_server,
             name=f"API Server",
-            kwargs=dict(started_event=api_started),
+            kwargs=dict(started_event=api_started, run_mode=run_mode),
             daemon=True,
         )
         processes["api"] = process

@@ -708,7 +726,7 @@ async def start_main_server():
         process = Process(
             target=run_webui,
             name=f"WEBUI Server",
-            kwargs=dict(started_event=webui_started),
+            kwargs=dict(started_event=webui_started, run_mode=run_mode),
             daemon=True,
         )
         processes["webui"] = process
tests/api/test_server_state_api.py (new file, 47 lines)

@@ -0,0 +1,47 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from webui_pages.utils import ApiRequest
+
+import pytest
+from pprint import pprint
+from typing import List
+
+
+api = ApiRequest()
+
+
+def test_get_default_llm():
+    llm = api.get_default_llm_model()
+
+    print(llm)
+    assert isinstance(llm, tuple)
+    assert isinstance(llm[0], str) and isinstance(llm[1], bool)
+
+
+def test_server_configs():
+    configs = api.get_server_configs()
+    pprint(configs, depth=2)
+
+    assert isinstance(configs, dict)
+    assert len(configs) > 0
+
+
+def test_list_search_engines():
+    engines = api.list_search_engines()
+    pprint(engines)
+
+    assert isinstance(engines, list)
+    assert len(engines) > 0
+
+
+@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
+def test_get_prompt_template(type):
+    print(f"prompt template for: {type}")
+    template = api.get_prompt_template(type=type)
+
+    print(template)
+    assert isinstance(template, str)
+    assert len(template) > 0
tests/api/test_stream_chat_api_thread.py (new file, 81 lines)

@@ -0,0 +1,81 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from configs import BING_SUBSCRIPTION_KEY
+from server.utils import api_address
+
+from pprint import pprint
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import time
+
+
+api_base_url = api_address()
+
+
+def dump_input(d, title):
+    print("\n")
+    print("=" * 30 + title + " input " + "="*30)
+    pprint(d)
+
+
+def dump_output(r, title):
+    print("\n")
+    print("=" * 30 + title + " output" + "="*30)
+    for line in r.iter_content(None, decode_unicode=True):
+        print(line, end="", flush=True)
+
+
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json',
+}
+
+
+def knowledge_chat(api="/chat/knowledge_base_chat"):
+    url = f"{api_base_url}{api}"
+    data = {
+        "query": "如何提问以获得高质量答案",
+        "knowledge_base_name": "samples",
+        "history": [
+            {
+                "role": "user",
+                "content": "你好"
+            },
+            {
+                "role": "assistant",
+                "content": "你好,我是 ChatGLM"
+            }
+        ],
+        "stream": True
+    }
+    result = []
+    response = requests.post(url, headers=headers, json=data, stream=True)
+
+    for line in response.iter_content(None, decode_unicode=True):
+        data = json.loads(line)
+        result.append(data)
+
+    return result
+
+
+def test_thread():
+    threads = []
+    times = []
+    pool = ThreadPoolExecutor()
+    start = time.time()
+    for i in range(10):
+        t = pool.submit(knowledge_chat)
+        threads.append(t)
+
+    for r in as_completed(threads):
+        end = time.time()
+        times.append(end - start)
+        print("\nResult:\n")
+        pprint(r.result())
+
+    print("\nTime used:\n")
+    for x in times:
+        print(f"{x}")
webui.py (19 changed lines)

@@ -1,8 +1,9 @@
 import streamlit as st
 from webui_pages.utils import *
 from streamlit_option_menu import option_menu
-from webui_pages import *
+from webui_pages.dialogue.dialogue import dialogue_page, chat_box
 import os
+import sys
 from configs import VERSION
 from server.utils import api_address

@@ -10,6 +11,8 @@ from server.utils import api_address
 api = ApiRequest(base_url=api_address())

 if __name__ == "__main__":
+    is_lite = "lite" in sys.argv
+
     st.set_page_config(
         "Langchain-Chatchat WebUI",
         os.path.join("img", "chatchat_icon_blue_square_v2.png"),

@@ -26,11 +29,15 @@ if __name__ == "__main__":
             "icon": "chat",
             "func": dialogue_page,
         },
-        "知识库管理": {
-            "icon": "hdd-stack",
-            "func": knowledge_base_page,
-        },
     }
+    if not is_lite:
+        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
+        pages.update({
+            "知识库管理": {
+                "icon": "hdd-stack",
+                "func": knowledge_base_page,
+            },
+        })

     with st.sidebar:
         st.image(

@@ -57,4 +64,4 @@ if __name__ == "__main__":
     )

     if selected_page in pages:
-        pages[selected_page]["func"](api)
+        pages[selected_page]["func"](api=api, is_lite=is_lite)
@@ -1,3 +0,0 @@
-from .dialogue import dialogue_page, chat_box
-from .knowledge_base import knowledge_base_page
-from .model_config import model_config_page

@@ -1 +0,0 @@
-from .dialogue import dialogue_page, chat_box
@@ -3,10 +3,11 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                      DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
 from typing import List, Dict


 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",

@@ -33,27 +34,9 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
     return chat_box.filter_history(history_len=history_len, filter=filter)


-def get_default_llm_model(api: ApiRequest) -> (str, bool):
-    '''
-    从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
-    返回类型为(model_name, is_local_model)
-    '''
-    running_models = api.list_running_models()
-    if not running_models:
-        return "", False
-
-    if LLM_MODEL in running_models:
-        return LLM_MODEL, True
-
-    local_models = [k for k, v in running_models.items() if not v.get("online_api")]
-    if local_models:
-        return local_models[0], True
-    return list(running_models)[0], False
-
-
-def dialogue_page(api: ApiRequest):
+def dialogue_page(api: ApiRequest, is_lite: bool = False):
     if not chat_box.chat_inited:
-        default_model = get_default_llm_model(api)[0]
+        default_model = api.get_default_llm_model()[0]
         st.toast(
             f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
             f"当前运行的模型`{default_model}`, 您可以开始提问了."

@@ -70,13 +53,19 @@ def dialogue_page(api: ApiRequest):
             text = f"{text} 当前知识库: `{cur_kb}`。"
             st.toast(text)

+        if is_lite:
+            dialogue_modes = ["LLM 对话",
+                              "搜索引擎问答",
+                              ]
+        else:
+            dialogue_modes = ["LLM 对话",
+                              "知识库问答",
+                              "搜索引擎问答",
+                              "自定义Agent问答",
+                              ]
         dialogue_mode = st.selectbox("请选择对话模式:",
-                                     ["LLM 对话",
-                                      "知识库问答",
-                                      "搜索引擎问答",
-                                      "自定义Agent问答",
-                                      ],
-                                     index=3,
+                                     dialogue_modes,
+                                     index=0,
                                      on_change=on_mode_change,
                                      key="dialogue_mode",
                                      )

@@ -107,7 +96,7 @@ def dialogue_page(api: ApiRequest):
         for k, v in config_models.get("langchain", {}).items():  # 列出LANGCHAIN_LLM_MODEL支持的模型
             available_models.append(k)
         llm_models = running_models + available_models
-        index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
+        index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,

@@ -116,9 +105,10 @@ def dialogue_page(api: ApiRequest):
                                  key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
+                and not is_lite
                 and not llm_model in config_models.get("online", {})
                 and not llm_model in config_models.get("langchain", {})
                 and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                 prev_model = st.session_state.get("prev_llm_model")
                 r = api.change_llm_model(prev_model, llm_model)

@@ -128,12 +118,18 @@ def dialogue_page(api: ApiRequest):
                 st.success(msg)
                 st.session_state["prev_llm_model"] = llm_model

-        index_prompt = {
-            "LLM 对话": "llm_chat",
-            "自定义Agent问答": "agent_chat",
-            "搜索引擎问答": "search_engine_chat",
-            "知识库问答": "knowledge_base_chat",
-        }
+        if is_lite:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "搜索引擎问答": "search_engine_chat",
+            }
+        else:
+            index_prompt = {
+                "LLM 对话": "llm_chat",
+                "自定义Agent问答": "agent_chat",
+                "搜索引擎问答": "search_engine_chat",
+                "知识库问答": "knowledge_base_chat",
+            }
         prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
         prompt_template_name = prompt_templates_kb_list[0]
         if "prompt_template_select" not in st.session_state:

@@ -284,7 +280,8 @@ def dialogue_page(api: ApiRequest):
                                              history=history,
                                              model=llm_model,
                                              prompt_name=prompt_template_name,
-                                             temperature=temperature):
+                                             temperature=temperature,
+                                             split_result=se_top_k>1):
             if error_msg := check_error_msg(d):  # check whether error occured
                 st.error(error_msg)
             elif chunk := d.get("answer"):
@@ -20,12 +20,11 @@ from configs (
     logger, log_verbose,
 )
 import httpx
-from server.chat.openai_chat import OpenAiChatMsgIn
 import contextlib
 import json
 import os
 from io import BytesIO
-from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
+from server.utils import set_httpx_config, api_address, get_httpx_client

 from pprint import pprint

@@ -213,7 +212,7 @@ class ApiRequest:
             if log_verbose:
                 logger.error(f'{e.__class__.__name__}: {msg}',
                              exc_info=e if log_verbose else None)
-            return {"code": 500, "msg": msg}
+            return {"code": 500, "msg": msg, "data": None}

         if value_func is None:
             value_func = (lambda r: r)

@@ -233,9 +232,26 @@ class ApiRequest:
             return value_func(response)

     # 服务器信息
-    def get_server_configs(self, **kwargs):
+    def get_server_configs(self, **kwargs) -> Dict:
         response = self.post("/server/configs", **kwargs)
-        return self._get_response_value(response, lambda r: r.json())
+        return self._get_response_value(response, as_json=True)
+
+    def list_search_engines(self, **kwargs) -> List:
+        response = self.post("/server/list_search_engines", **kwargs)
+        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
+
+    def get_prompt_template(
+        self,
+        type: str = "llm_chat",
+        name: str = "default",
+        **kwargs,
+    ) -> str:
+        data = {
+            "type": type,
+            "name": name,
+        }
+        response = self.post("/server/get_prompt_template", json=data, **kwargs)
+        return self._get_response_value(response, value_func=lambda r: r.text)

     # 对话相关操作

@@ -251,16 +267,14 @@ class ApiRequest:
         '''
         对应api.py/chat/fastchat接口
         '''
-        msg = OpenAiChatMsgIn(**{
+        data = {
             "messages": messages,
             "stream": stream,
             "model": model,
             "temperature": temperature,
             "max_tokens": max_tokens,
-            **kwargs,
-        })
-
-        data = msg.dict(exclude_unset=True, exclude_none=True)
+        }
+
         print(f"received input message:")
         pprint(data)

@@ -268,6 +282,7 @@ class ApiRequest:
             "/chat/fastchat",
             json=data,
             stream=True,
+            **kwargs,
         )
         return self._httpx_stream2generator(response)

@@ -380,6 +395,7 @@ class ApiRequest:
         temperature: float = TEMPERATURE,
         max_tokens: int = None,
         prompt_name: str = "default",
+        split_result: bool = False,
     ):
         '''
         对应api.py/chat/search_engine_chat接口

@@ -394,6 +410,7 @@ class ApiRequest:
             "temperature": temperature,
             "max_tokens": max_tokens,
             "prompt_name": prompt_name,
+            "split_result": split_result,
         }

         print(f"received input message:")

@@ -659,6 +676,43 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))

+    def get_default_llm_model(self) -> (str, bool):
+        '''
+        从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
+        返回类型为(model_name, is_local_model)
+        '''
+        def ret_sync():
+            running_models = self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+            return list(running_models)[0], False
+
+        async def ret_async():
+            running_models = await self.list_running_models()
+            if not running_models:
+                return "", False
+
+            if LLM_MODEL in running_models:
+                return LLM_MODEL, True
+
+            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
+            if local_models:
+                return local_models[0], True
+            return list(running_models)[0], False
+
+        if self._use_async:
+            return ret_async()
+        else:
+            return ret_sync()
+
     def list_config_models(self) -> Dict[str, List[str]]:
         '''
         获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
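
A hedged usage sketch for the ApiRequest helpers added or changed above; it assumes the API server is already running at the address returned by api_address(), which mirrors what tests/api/test_server_state_api.py does:

    from webui_pages.utils import ApiRequest

    api = ApiRequest()
    model, is_local = api.get_default_llm_model()        # (model_name, is_local_model)
    template = api.get_prompt_template(type="llm_chat")  # raw prompt template text
    engines = api.list_search_engines()                  # names of the supported search engines
    print(model, is_local, engines)
    print(template[:80])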