Support lite mode: run without heavy dependencies such as torch; LLM chat and search-engine chat are served through online APIs (#1860)
* Move get_default_llm_model from webui to ApiRequest. Add API endpoints and their test cases:
  - /server/get_prompt_template: return the prompt templates configured on the server
  - add a multi-threaded knowledge-base access test case
  Support lite mode: no heavy dependencies such as torch need to be installed; LLM chat and search-engine chat work through online APIs.
* Fix bug in server.api
---------
Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
parent be67ea43d8
commit 03e55e11c4
requirements_lite.txt (new file, 60 lines)

@@ -0,0 +1,60 @@
langchain>=0.0.319
fschat>=0.2.31
openai
# sentence_transformers
# transformers>=4.33.0
# torch>=2.0.1
# torchvision
# torchaudio
fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
# unstructured[all-docs]>=0.10.4
# python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
# faiss-cpu
# accelerate
# spacy
# PyMuPDF==1.22.5
# rapidocr_onnxruntime>=1.3.2

requests
pathlib
pytest
# scikit-learn
# numexpr
# vllm==0.1.7; sys_platform == "linux"
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou

# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
# pgvector

numpy~=1.24.4
pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets
# tiktoken
einops
# scipy
# transformers_stream_generator==0.0.4

# search engine libs
duckduckgo-search
metaphor-python
strsimpy
markdownify
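requirements_lite.txt deliberately comments out torch, transformers and the other local-inference packages, so a lite environment can be sanity-checked simply by importing the online-API stack. A minimal sketch, assuming the file above was installed with pip install -r requirements_lite.txt (the module list is illustrative; note that fschat imports as fastchat):

# check_lite_env.py: sanity-check a lite install.
import importlib.util

LITE_MODULES = ["langchain", "fastchat", "openai", "fastapi", "streamlit", "zhipuai", "dashscope"]
HEAVY_MODULES = ["torch", "transformers"]  # intentionally absent in lite mode

for name in LITE_MODULES:
    importlib.import_module(name)  # raises ImportError if the lite stack is incomplete

absent = [m for m in HEAVY_MODULES if importlib.util.find_spec(m) is None]
print("lite stack imports OK; heavy deps not installed:", absent)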
@@ -1,5 +1,5 @@
import json
from server.chat import search_engine_chat
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K
import asyncio
from server.agent import model_container
server/api.py (133 lines changed)

@@ -9,19 +9,19 @@ from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (completion,chat, knowledge_base_chat, openai_chat,
                         search_engine_chat, agent_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                              update_docs, download_doc, recreate_vector_store,
                                              search_docs, DocumentWithScore, update_info)
from server.chat.chat import chat
from server.chat.openai_chat import openai_chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.llm_api import (list_running_models, list_config_models,
                            change_llm_model, stop_llm_model,
                            get_model_config, list_search_engines)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
from typing import List
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                          get_server_configs, get_prompt_template)
from typing import List, Literal

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -30,7 +30,7 @@ async def document():
    return RedirectResponse(url="/docs")


def create_app():
def create_app(run_mode: str = None):
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION
@@ -47,7 +47,13 @@ def create_app():
        allow_methods=["*"],
        allow_headers=["*"],
    )
    mount_basic_routes(app)
    if run_mode != "lite":
        mount_knowledge_routes(app)
    return app


def mount_basic_routes(app: FastAPI):
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)
@@ -65,14 +71,69 @@ def create_app():
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)")(chat)

    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledge_base_chat)

    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话")(search_engine_chat)

    # LLM模型相关接口
    app.post("/llm_model/list_running_models",
             tags=["LLM Model Management"],
             summary="列出当前已加载的模型",
             )(list_running_models)

    app.post("/llm_model/list_config_models",
             tags=["LLM Model Management"],
             summary="列出configs已配置的模型",
             )(list_config_models)

    app.post("/llm_model/get_model_config",
             tags=["LLM Model Management"],
             summary="获取模型配置(合并后)",
             )(get_model_config)

    app.post("/llm_model/stop",
             tags=["LLM Model Management"],
             summary="停止指定的LLM模型(Model Worker)",
             )(stop_llm_model)

    app.post("/llm_model/change",
             tags=["LLM Model Management"],
             summary="切换指定的LLM模型(Model Worker)",
             )(change_llm_model)

    # 服务器相关接口
    app.post("/server/configs",
             tags=["Server State"],
             summary="获取服务器原始配置信息",
             )(get_server_configs)

    app.post("/server/list_search_engines",
             tags=["Server State"],
             summary="获取服务器支持的搜索引擎",
             )(list_search_engines)

    @app.post("/server/get_prompt_template",
              tags=["Server State"],
              summary="获取服务区配置的 prompt 模板")
    def get_server_prompt_template(
        type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
        name: str = Body("default", description="模板名称"),
    ) -> str:
        return get_prompt_template(type=type, name=name)


def mount_knowledge_routes(app: FastAPI):
    from server.chat.knowledge_base_chat import knowledge_base_chat
    from server.chat.agent_chat import agent_chat
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                  update_docs, download_doc, recreate_vector_store,
                                                  search_docs, DocumentWithScore, update_info)

    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledge_base_chat)

    app.post("/chat/agent_chat",
             tags=["Chat"],
             summary="与agent对话")(agent_chat)
@@ -139,48 +200,6 @@ def create_app():
             summary="根据content中文档重建向量库,流式输出处理进度。"
             )(recreate_vector_store)

    # LLM模型相关接口
    app.post("/llm_model/list_running_models",
             tags=["LLM Model Management"],
             summary="列出当前已加载的模型",
             )(list_running_models)

    app.post("/llm_model/list_config_models",
             tags=["LLM Model Management"],
             summary="列出configs已配置的模型",
             )(list_config_models)

    app.post("/llm_model/get_model_config",
             tags=["LLM Model Management"],
             summary="获取模型配置(合并后)",
             )(get_model_config)

    app.post("/llm_model/stop",
             tags=["LLM Model Management"],
             summary="停止指定的LLM模型(Model Worker)",
             )(stop_llm_model)

    app.post("/llm_model/change",
             tags=["LLM Model Management"],
             summary="切换指定的LLM模型(Model Worker)",
             )(change_llm_model)

    # 服务器相关接口
    app.post("/server/configs",
             tags=["Server State"],
             summary="获取服务器原始配置信息",
             )(get_server_configs)

    app.post("/server/list_search_engines",
             tags=["Server State"],
             summary="获取服务器支持的搜索引擎",
             )(list_search_engines)

    return app


app = create_app()


def run_api(host, port, **kwargs):
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
@@ -205,6 +224,10 @@ if __name__ == "__main__":
    # 初始化消息
    args = parser.parse_args()
    args_dict = vars(args)

    app = create_app()
    mount_knowledge_routes(app)

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
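The new /server/get_prompt_template endpoint exposes the templates from configs/prompt_config.py over HTTP. A quick client-side sketch, assuming the API server above is reachable at the address configured in API_SERVER (the base URL below is illustrative):

# Fetch a configured prompt template from the running API server.
import requests

API_BASE = "http://127.0.0.1:7861"  # assumption: adjust to your API_SERVER setting

resp = requests.post(
    f"{API_BASE}/server/get_prompt_template",
    json={"type": "search_engine_chat", "name": "default"},  # any of the Literal values above
)
resp.raise_for_status()
print(resp.text)  # the endpoint returns the template string itself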
@@ -1,6 +0,0 @@
from .chat import chat
from .completion import completion
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .search_engine_chat import search_engine_chat
from .agent_chat import agent_chat
@@ -1,4 +1,5 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                     LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                     TEXT_SPLITTER_NAME, OVERLAP_SIZE)
@@ -12,13 +13,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
import json
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify


def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
        return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
                 "title": "env info is not found",
@@ -28,7 +32,7 @@ def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
    return search.results(text, result_len)


def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
    search = DuckDuckGoSearchAPIWrapper()
    return search.results(text, result_len)

@@ -36,40 +40,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
def metaphor_search(
        text: str,
        result_len: int = SEARCH_ENGINE_TOP_K,
        splitter_name: str = "SpacyTextSplitter",
        split_result: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
    from metaphor_python import Metaphor
    from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
    from server.knowledge_base.utils import make_text_splitter

    if not METAPHOR_API_KEY:
        return []


    client = Metaphor(METAPHOR_API_KEY)
    search = client.search(text, num_results=result_len, use_autoprompt=True)
    contents = search.get_contents().contents
    for x in contents:
        x.extract = markdownify(x.extract)

    # metaphor 返回的内容都是长文本,需要分词再检索
    docs = [Document(page_content=x.extract,
                     metadata={"link": x.url, "title": x.title})
            for x in contents]
    text_splitter = make_text_splitter(splitter_name=splitter_name,
                                       chunk_size=chunk_size,
                                       chunk_overlap=chunk_overlap)
    splitted_docs = text_splitter.split_documents(docs)

    # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
    if len(splitted_docs) > result_len:
        vs = memo_faiss_pool.new_vector_store()
        vs.add_documents(splitted_docs)
        splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
    if split_result:
        docs = [Document(page_content=x.extract,
                         metadata={"link": x.url, "title": x.title})
                for x in contents]
        text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                       chunk_size=chunk_size,
                                                       chunk_overlap=chunk_overlap)
        splitted_docs = text_splitter.split_documents(docs)

        # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
        if len(splitted_docs) > result_len:
            normal = NormalizedLevenshtein()
            for x in splitted_docs:
                x.metadata["score"] = normal.similarity(text, x.page_content)
            splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
            splitted_docs = splitted_docs[:result_len]

        docs = [{"snippet": x.page_content,
                 "link": x.metadata["link"],
                 "title": x.metadata["title"]}
                for x in splitted_docs]
    else:
        docs = [{"snippet": x.extract,
                 "link": x.url,
                 "title": x.title}
                for x in contents]

    docs = [{"snippet": x.page_content,
             "link": x.metadata["link"],
             "title": x.metadata["title"]}
            for x in splitted_docs]
    return docs


@@ -93,9 +106,10 @@ async def lookup_search_engine(
        query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        split_result: bool = False,
):
    search_engine = SEARCH_ENGINES[search_engine_name]
    results = await run_in_threadpool(search_engine, query, result_len=top_k)
    results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
    docs = search_result2docs(results)
    return docs

@@ -116,6 +130,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                             temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                             max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                             prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
                             ):
    if search_engine_name not in SEARCH_ENGINES.keys():
        return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -140,7 +155,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
        callbacks=[callback],
    )

    docs = await lookup_search_engine(query, search_engine_name, top_k)
    docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
    context = "\n".join([doc.page_content for doc in docs])

    prompt_template = get_prompt_template("search_engine_chat", prompt_name)
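The rewritten metaphor_search no longer needs a temporary FAISS store: when split_result is true it splits the long page extracts with RecursiveCharacterTextSplitter and reranks the chunks against the query by normalized Levenshtein similarity. A standalone sketch of that reranking step (the sample chunks are made up; strsimpy is the same library the diff uses):

# Rerank split search-result chunks by string similarity to the query,
# mirroring the NormalizedLevenshtein logic in metaphor_search above.
from typing import List
from strsimpy.normalized_levenshtein import NormalizedLevenshtein

def rerank(query: str, chunks: List[str], top_k: int) -> List[str]:
    normal = NormalizedLevenshtein()
    scored = [(normal.similarity(query, chunk), chunk) for chunk in chunks]
    scored.sort(key=lambda pair: pair[0], reverse=True)  # highest similarity first
    return [chunk for _, chunk in scored[:top_k]]

if __name__ == "__main__":
    chunks = ["how to ask good questions", "installation guide", "asking high quality questions"]
    print(rerank("how to ask a high quality question", chunks, top_k=2))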
@@ -4,6 +4,8 @@
import json
import time
import hashlib

from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config, get_httpx_client
from fastchat import conversation as conv
@@ -78,16 +80,6 @@ class BaiChuanWorker(ApiModelWorker):
        kwargs.setdefault("context_len", 32768)
        super().__init__(**kwargs)

        # TODO: 确认模板是否需要修改
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )

        config = self.get_config()
        self.version = config.get("version",version)
        self.api_key = config.get("api_key")
@@ -127,6 +119,18 @@ class BaiChuanWorker(ApiModelWorker):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
@@ -61,6 +62,9 @@ class ApiModelWorker(BaseModelWorker):
        print("embedding")
        # print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError

    # help methods
    def get_config(self):
        from server.utils import get_model_worker_config
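All the workers below follow the same refactor: the hard-coded self.conv assignment in __init__ is dropped and the template is produced by the new make_conv_template hook declared here on ApiModelWorker. A minimal sketch of how an additional online-API worker would plug into this pattern (the class name and template values are hypothetical):

# Hypothetical worker showing the make_conv_template hook; not part of this commit.
from fastchat.conversation import Conversation
from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker

class ExampleApiWorker(ApiModelWorker):
    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # Return the template instead of assigning self.conv inside __init__.
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )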
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
@@ -74,15 +75,6 @@ class FangZhouWorker(ApiModelWorker):
        self.api_key = config.get("api_key")
        self.secret_key = config.get("secret_key")

        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        super().generate_stream_gate(params)

@@ -107,6 +99,16 @@ class FangZhouWorker(ApiModelWorker):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@@ -22,16 +23,6 @@ class MiniMaxWorker(ApiModelWorker):
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: 确认模板是否需要修改
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["USER", "BOT"],
            sep="\n### ",
            stop_str="###",
        )

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        result = super().prompt_to_messages(prompt)
        messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
@@ -86,6 +77,17 @@ class MiniMaxWorker(ApiModelWorker):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["USER", "BOT"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
@@ -120,16 +121,6 @@ class QianFanWorker(ApiModelWorker):
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: 确认模板是否需要修改
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )

        config = self.get_config()
        self.version = version
        self.api_key = config.get("api_key")
@@ -162,6 +153,17 @@ class QianFanWorker(ApiModelWorker):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -1,5 +1,7 @@
import json
import sys

from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
@@ -68,15 +70,6 @@ class QwenWorker(ApiModelWorker):
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: 确认模板是否需要修改
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )
        config = self.get_config()
        self.api_key = config.get("api_key")
        self.version = version
@@ -108,6 +101,17 @@ class QwenWorker(ApiModelWorker):
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@@ -38,16 +39,6 @@ class XingHuoWorker(ApiModelWorker):
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)

        # TODO: 确认模板是否需要修改
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接

@@ -86,6 +77,17 @@
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
            messages=[],
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -1,3 +1,4 @@
from fastchat.conversation import Conversation
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@@ -23,16 +24,6 @@ class ChatGLMWorker(ApiModelWorker):
        super().__init__(**kwargs)
        self.version = version

        # 这里的是chatglm api的模板,其它API的conv_template需要定制
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["Human", "Assistant"],
            sep="\n###",
            stop_str="###",
        )

    def generate_stream_gate(self, params):
        # TODO: 维护request_id
        import zhipuai
@@ -59,6 +50,17 @@ class ChatGLMWorker(ApiModelWorker):
        print("embedding")
        # print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        # 这里的是chatglm api的模板,其它API的conv_template需要定制
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
            messages=[],
            roles=["Human", "Assistant"],
            sep="\n###",
            stop_str="###",
        )


if __name__ == "__main__":
    import uvicorn
@@ -477,7 +477,7 @@ def fschat_controller_address() -> str:


def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    if model := get_model_worker_config(model_name):
    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
        host = model["host"]
        if host == "0.0.0.0":
            host = "127.0.0.1"
startup.py (44 lines changed)

@@ -423,13 +423,13 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
    uvicorn.run(app, host=host, port=port)


def run_api_server(started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
    from server.api import create_app
    import uvicorn
    from server.utils import set_httpx_config
    set_httpx_config()

    app = create_app()
    app = create_app(run_mode=run_mode)
    _set_app_event(app, started_event)

    host = API_SERVER["host"]
@@ -438,21 +438,27 @@ def run_api_server(started_event: mp.Event = None):
    uvicorn.run(app, host=host, port=port)


def run_webui(started_event: mp.Event = None):
def run_webui(started_event: mp.Event = None, run_mode: str = None):
    from server.utils import set_httpx_config
    set_httpx_config()

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]

    p = subprocess.Popen(["streamlit", "run", "webui.py",
                          "--server.address", host,
                          "--server.port", str(port),
                          "--theme.base", "light",
                          "--theme.primaryColor", "#165dff",
                          "--theme.secondaryBackgroundColor", "#f5f5f5",
                          "--theme.textColor", "#000000",
                          ])
    cmd = ["streamlit", "run", "webui.py",
           "--server.address", host,
           "--server.port", str(port),
           "--theme.base", "light",
           "--theme.primaryColor", "#165dff",
           "--theme.secondaryBackgroundColor", "#f5f5f5",
           "--theme.textColor", "#000000",
           ]
    if run_mode == "lite":
        cmd += [
            "--",
            "lite",
        ]
    p = subprocess.Popen(cmd)
    started_event.set()
    p.wait()

@@ -535,6 +541,13 @@ def parse_args() -> argparse.ArgumentParser:
        help="减少fastchat服务log信息",
        dest="quiet",
    )
    parser.add_argument(
        "-i",
        "--lite",
        action="store_true",
        help="以Lite模式运行:仅支持在线API的LLM对话、搜索引擎对话",
        dest="lite",
    )
    args = parser.parse_args()
    return args, parser

@@ -596,6 +609,7 @@ async def start_main_server():

    mp.set_start_method("spawn")
    manager = mp.Manager()
    run_mode = None

    queue = manager.Queue()
    args, parser = parse_args()
@@ -621,6 +635,10 @@ async def start_main_server():
        args.api = False
        args.webui = False

    if args.lite:
        args.model_worker = False
        run_mode = "lite"

    dump_server_info(args=args)

    if len(sys.argv) > 1:
@@ -698,7 +716,7 @@ async def start_main_server():
        process = Process(
            target=run_api_server,
            name=f"API Server",
            kwargs=dict(started_event=api_started),
            kwargs=dict(started_event=api_started, run_mode=run_mode),
            daemon=True,
        )
        processes["api"] = process
@@ -708,7 +726,7 @@ async def start_main_server():
        process = Process(
            target=run_webui,
            name=f"WEBUI Server",
            kwargs=dict(started_event=webui_started),
            kwargs=dict(started_event=webui_started, run_mode=run_mode),
            daemon=True,
        )
        processes["webui"] = process
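With these changes, passing --lite (or -i) to startup.py sets run_mode to "lite", skips the local model-worker processes, hands run_mode to create_app() so the knowledge-base routes are never mounted, and forwards a "lite" marker to the Streamlit web UI. For just the API process, the equivalent of lite mode can be sketched directly (host and port below are illustrative):

# Run only the API server in lite mode; equivalent in spirit to
# `python startup.py --lite` restricted to the API process.
import uvicorn
from server.api import create_app

app = create_app(run_mode="lite")  # knowledge-base and agent routes are not mounted
uvicorn.run(app, host="127.0.0.1", port=7861)  # assumption: match your API_SERVER config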
tests/api/test_server_state_api.py (new file, 47 lines)

@@ -0,0 +1,47 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from webui_pages.utils import ApiRequest

import pytest
from pprint import pprint
from typing import List


api = ApiRequest()


def test_get_default_llm():
    llm = api.get_default_llm_model()

    print(llm)
    assert isinstance(llm, tuple)
    assert isinstance(llm[0], str) and isinstance(llm[1], bool)


def test_server_configs():
    configs = api.get_server_configs()
    pprint(configs, depth=2)

    assert isinstance(configs, dict)
    assert len(configs) > 0


def test_list_search_engines():
    engines = api.list_search_engines()
    pprint(engines)

    assert isinstance(engines, list)
    assert len(engines) > 0


@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
def test_get_prompt_template(type):
    print(f"prompt template for: {type}")
    template = api.get_prompt_template(type=type)

    print(template)
    assert isinstance(template, str)
    assert len(template) > 0
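These tests talk to a live server through ApiRequest, so they assume the API service is already running (for example launched via startup.py). A small runner sketch:

# Run the new server-state API tests programmatically (assumption: the API server
# is already up and reachable at the configured address).
import pytest

raise SystemExit(pytest.main(["-q", "tests/api/test_server_state_api.py"]))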
tests/api/test_stream_chat_api_thread.py (new file, 81 lines)

@@ -0,0 +1,81 @@
import requests
import json
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address

from pprint import pprint
from concurrent.futures import ThreadPoolExecutor, as_completed
import time


api_base_url = api_address()


def dump_input(d, title):
    print("\n")
    print("=" * 30 + title + " input " + "="*30)
    pprint(d)


def dump_output(r, title):
    print("\n")
    print("=" * 30 + title + " output" + "="*30)
    for line in r.iter_content(None, decode_unicode=True):
        print(line, end="", flush=True)


headers = {
    'accept': 'application/json',
    'Content-Type': 'application/json',
}


def knowledge_chat(api="/chat/knowledge_base_chat"):
    url = f"{api_base_url}{api}"
    data = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {
                "role": "user",
                "content": "你好"
            },
            {
                "role": "assistant",
                "content": "你好,我是 ChatGLM"
            }
        ],
        "stream": True
    }
    result = []
    response = requests.post(url, headers=headers, json=data, stream=True)

    for line in response.iter_content(None, decode_unicode=True):
        data = json.loads(line)
        result.append(data)

    return result


def test_thread():
    threads = []
    times = []
    pool = ThreadPoolExecutor()
    start = time.time()
    for i in range(10):
        t = pool.submit(knowledge_chat)
        threads.append(t)

    for r in as_completed(threads):
        end = time.time()
        times.append(end - start)
        print("\nResult:\n")
        pprint(r.result())

    print("\nTime used:\n")
    for x in times:
        print(f"{x}")
webui.py (19 lines changed)

@@ -1,8 +1,9 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_option_menu import option_menu
from webui_pages import *
from webui_pages.dialogue.dialogue import dialogue_page, chat_box
import os
import sys
from configs import VERSION
from server.utils import api_address

@@ -10,6 +11,8 @@ from server.utils import api_address
api = ApiRequest(base_url=api_address())

if __name__ == "__main__":
    is_lite = "lite" in sys.argv

    st.set_page_config(
        "Langchain-Chatchat WebUI",
        os.path.join("img", "chatchat_icon_blue_square_v2.png"),
@@ -26,11 +29,15 @@ if __name__ == "__main__":
            "icon": "chat",
            "func": dialogue_page,
        },
        "知识库管理": {
            "icon": "hdd-stack",
            "func": knowledge_base_page,
        },
    }
    if not is_lite:
        from webui_pages.knowledge_base.knowledge_base import knowledge_base_page
        pages.update({
            "知识库管理": {
                "icon": "hdd-stack",
                "func": knowledge_base_page,
            },
        })

    with st.sidebar:
        st.image(
@@ -57,4 +64,4 @@ if __name__ == "__main__":
        )

    if selected_page in pages:
        pages[selected_page]["func"](api)
        pages[selected_page]["func"](api=api, is_lite=is_lite)
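webui.py learns about lite mode through its own argument list: startup.py appends "--", "lite" to the streamlit command, and Streamlit forwards everything after "--" to the script, so "lite" ends up in sys.argv. A tiny sketch of that hand-off when launching by hand (normally startup.py builds this command for you):

# How the lite marker reaches webui.py when launched manually.
import subprocess
import sys

cmd = [sys.executable, "-m", "streamlit", "run", "webui.py", "--", "lite"]
# Inside webui.py this makes `"lite" in sys.argv` true, so the knowledge-base
# page is never imported or registered.
subprocess.run(cmd, check=False)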
@@ -1,3 +0,0 @@
from .dialogue import dialogue_page, chat_box
from .knowledge_base import knowledge_base_page
from .model_config import model_config_page
@@ -1 +0,0 @@
from .dialogue import dialogue_page, chat_box
@@ -3,10 +3,11 @@ from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
import os
from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
from typing import List, Dict


chat_box = ChatBox(
    assistant_avatar=os.path.join(
        "img",
@@ -33,27 +34,9 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
    return chat_box.filter_history(history_len=history_len, filter=filter)


def get_default_llm_model(api: ApiRequest) -> (str, bool):
    '''
    从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
    返回类型为(model_name, is_local_model)
    '''
    running_models = api.list_running_models()
    if not running_models:
        return "", False

    if LLM_MODEL in running_models:
        return LLM_MODEL, True

    local_models = [k for k, v in running_models.items() if not v.get("online_api")]
    if local_models:
        return local_models[0], True
    return list(running_models)[0], False


def dialogue_page(api: ApiRequest):
def dialogue_page(api: ApiRequest, is_lite: bool = False):
    if not chat_box.chat_inited:
        default_model = get_default_llm_model(api)[0]
        default_model = api.get_default_llm_model()[0]
        st.toast(
            f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
            f"当前运行的模型`{default_model}`, 您可以开始提问了."
@@ -70,13 +53,19 @@ def dialogue_page(api: ApiRequest):
                text = f"{text} 当前知识库: `{cur_kb}`。"
            st.toast(text)

        if is_lite:
            dialogue_modes = ["LLM 对话",
                              "搜索引擎问答",
                              ]
        else:
            dialogue_modes = ["LLM 对话",
                              "知识库问答",
                              "搜索引擎问答",
                              "自定义Agent问答",
                              ]
        dialogue_mode = st.selectbox("请选择对话模式:",
                                     ["LLM 对话",
                                      "知识库问答",
                                      "搜索引擎问答",
                                      "自定义Agent问答",
                                      ],
                                     index=3,
                                     dialogue_modes,
                                     index=0,
                                     on_change=on_mode_change,
                                     key="dialogue_mode",
                                     )
@@ -107,7 +96,7 @@ def dialogue_page(api: ApiRequest):
            for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
                available_models.append(k)
            llm_models = running_models + available_models
            index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
            index = llm_models.index(st.session_state.get("cur_llm_model", api.get_default_llm_model()[0]))
            llm_model = st.selectbox("选择LLM模型:",
                                     llm_models,
                                     index,
@@ -116,9 +105,10 @@ def dialogue_page(api: ApiRequest):
                                     key="llm_model",
                                     )
            if (st.session_state.get("prev_llm_model") != llm_model
                    and not llm_model in config_models.get("online", {})
                    and not llm_model in config_models.get("langchain", {})
                    and llm_model not in running_models):
                    and not is_lite
                    and not llm_model in config_models.get("online", {})
                    and not llm_model in config_models.get("langchain", {})
                    and llm_model not in running_models):
                with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
                    prev_model = st.session_state.get("prev_llm_model")
                    r = api.change_llm_model(prev_model, llm_model)
@@ -128,12 +118,18 @@ def dialogue_page(api: ApiRequest):
                    st.success(msg)
                    st.session_state["prev_llm_model"] = llm_model

            index_prompt = {
                "LLM 对话": "llm_chat",
                "自定义Agent问答": "agent_chat",
                "搜索引擎问答": "search_engine_chat",
                "知识库问答": "knowledge_base_chat",
            }
            if is_lite:
                index_prompt = {
                    "LLM 对话": "llm_chat",
                    "搜索引擎问答": "search_engine_chat",
                }
            else:
                index_prompt = {
                    "LLM 对话": "llm_chat",
                    "自定义Agent问答": "agent_chat",
                    "搜索引擎问答": "search_engine_chat",
                    "知识库问答": "knowledge_base_chat",
                }
            prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
            prompt_template_name = prompt_templates_kb_list[0]
            if "prompt_template_select" not in st.session_state:
@@ -284,7 +280,8 @@ def dialogue_page(api: ApiRequest):
                                                 history=history,
                                                 model=llm_model,
                                                 prompt_name=prompt_template_name,
                                                 temperature=temperature):
                                                 temperature=temperature,
                                                 split_result=se_top_k>1):
                if error_msg := check_error_msg(d): # check whether error occured
                    st.error(error_msg)
                elif chunk := d.get("answer"):
@@ -20,12 +20,11 @@ from configs import (
    logger, log_verbose,
)
import httpx
from server.chat.openai_chat import OpenAiChatMsgIn
import contextlib
import json
import os
from io import BytesIO
from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
from server.utils import set_httpx_config, api_address, get_httpx_client

from pprint import pprint

@@ -213,7 +212,7 @@ class ApiRequest:
        if log_verbose:
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
        return {"code": 500, "msg": msg}
        return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)
@@ -233,9 +232,26 @@ class ApiRequest:
            return value_func(response)

    # 服务器信息
    def get_server_configs(self, **kwargs):
    def get_server_configs(self, **kwargs) -> Dict:
        response = self.post("/server/configs", **kwargs)
        return self._get_response_value(response, lambda r: r.json())
        return self._get_response_value(response, as_json=True)

    def list_search_engines(self, **kwargs) -> List:
        response = self.post("/server/list_search_engines", **kwargs)
        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

    def get_prompt_template(
        self,
        type: str = "llm_chat",
        name: str = "default",
        **kwargs,
    ) -> str:
        data = {
            "type": type,
            "name": name,
        }
        response = self.post("/server/get_prompt_template", json=data, **kwargs)
        return self._get_response_value(response, value_func=lambda r: r.text)

    # 对话相关操作

@@ -251,16 +267,14 @@ class ApiRequest:
        '''
        对应api.py/chat/fastchat接口
        '''
        msg = OpenAiChatMsgIn(**{
        data = {
            "messages": messages,
            "stream": stream,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        })
        }

        data = msg.dict(exclude_unset=True, exclude_none=True)
        print(f"received input message:")
        pprint(data)

@@ -268,6 +282,7 @@ class ApiRequest:
            "/chat/fastchat",
            json=data,
            stream=True,
            **kwargs,
        )
        return self._httpx_stream2generator(response)

@@ -380,6 +395,7 @@ class ApiRequest:
            temperature: float = TEMPERATURE,
            max_tokens: int = None,
            prompt_name: str = "default",
            split_result: bool = False,
    ):
        '''
        对应api.py/chat/search_engine_chat接口
@@ -394,6 +410,7 @@ class ApiRequest:
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
            "split_result": split_result,
        }

        print(f"received input message:")
@@ -659,6 +676,43 @@ class ApiRequest:
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))


    def get_default_llm_model(self) -> (str, bool):
        '''
        从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
        返回类型为(model_name, is_local_model)
        '''
        def ret_sync():
            running_models = self.list_running_models()
            if not running_models:
                return "", False

            if LLM_MODEL in running_models:
                return LLM_MODEL, True

            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
            if local_models:
                return local_models[0], True
            return list(running_models)[0], False

        async def ret_async():
            running_models = await self.list_running_models()
            if not running_models:
                return "", False

            if LLM_MODEL in running_models:
                return LLM_MODEL, True

            local_models = [k for k, v in running_models.items() if not v.get("online_api")]
            if local_models:
                return local_models[0], True
            return list(running_models)[0], False

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()

    def list_config_models(self) -> Dict[str, List[str]]:
        '''
        获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
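Taken together, the client side now covers everything lite mode needs: ApiRequest can discover the active model, fetch prompt templates, and drive search-engine chat with the new split_result flag. A short usage sketch against a running API server (the base URL is illustrative, and the search_engine_chat keyword names follow the call made in dialogue.py, so treat them as assumptions rather than a documented signature):

from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumption: your API server address

model, is_local = api.get_default_llm_model()
print("default model:", model, "(local)" if is_local else "(online API)")

print(api.get_prompt_template(type="search_engine_chat"))

# Stream a search-engine chat; split_result mainly matters for the metaphor engine.
for chunk in api.search_engine_chat(
        query="如何提问以获得高质量答案",
        search_engine_name="duckduckgo",
        top_k=3,
        split_result=True,
):
    print(chunk.get("answer", ""), end="", flush=True)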