Separate the API and the WebUI into front end and back end (#1772)

* update ApiRequest: remove the no_remote_api local-call mode; support both synchronous and asynchronous calls
* Separate the API from the WebUI (see the sketch after this message):
- configuration on the machine running the API is exposed through the /llm_model/get_model_config and /server/configs endpoints; configuration on the machine running the WebUI is only used as in-code defaults
- search engines available on the server are exposed through /server/list_search_engines
- the WebUI's selectable LLM list only shows models configured in FSCHAT_MODEL_WORKERS
- the WebUI now reads its default LLM_MODEL from the API side
- remove the `local_doc_url` parameter from knowledge_base_chat

Other changes:
- remove the redundant kb_config.py.exmaple (misspelled filename)
- disable vllm by default in server_config
- comment out all online models except Zhipu AI by default in server_config
- stop requests from picking up system proxies, to avoid model worker registration errors

* Fixes:
- api.list_config_models returns the models' raw configuration
- filter sensitive information of online API models out of api.list_config_models and api.get_model_config
- include directly accessed models such as GPT in the WebUI's selectable model list

Other:
- pin langchain==0.0.313, fschat==0.2.30, langchain-experimental==0.0.30
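
The new server-side endpoints can be exercised directly over HTTP. Below is a minimal sketch (not part of this commit); it assumes api.py is already running on the project's default address http://127.0.0.1:7861 and that a model named "chatglm2-6b" is configured — both may differ in your deployment.

import requests

base_url = "http://127.0.0.1:7861"  # default api.py address; adjust to your deployment

# raw configuration of the machine running the API
server_configs = requests.post(f"{base_url}/server/configs").json()["data"]

# search engines configured on the server
search_engines = requests.post(f"{base_url}/server/list_search_engines").json()["data"]

# merged configuration for one model (sensitive keys are filtered out server-side)
model_config = requests.post(
    f"{base_url}/llm_model/get_model_config",
    json={"model_name": "chatglm2-6b"},  # hypothetical model name
).json()["data"]

print(server_configs, search_engines, model_config)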
liunux4odoo 2023-10-17 16:52:07 +08:00 committed by GitHub
parent 94977c7ab1
commit 9ce328fea9
13 changed files with 521 additions and 644 deletions

View File

@@ -1,99 +0,0 @@
import os
# 默认向量库类型。可选faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
CACHED_VS_NUM = 1
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3
# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结题数量
SEARCH_ENGINE_TOP_K = 3
# Bing 搜索必备变量
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
# 具体申请方式请见
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# 使用python创建bing api 搜索实例详见:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# 注意不是bing Webmaster Tools的api key
# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
# 通常情况下不需要更改以下内容
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# 可选向量库类型及对应配置
kbs_config = {
"faiss": {
},
"milvus": {
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "gpt2",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
},
}
# TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
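
As an aside, a minimal sketch of how the TextSplitter settings above could be consumed (this helper is not part of the deleted file; it assumes langchain and transformers are installed and that text_splitter_dict, TEXT_SPLITTER_NAME, CHUNK_SIZE and OVERLAP_SIZE are imported from the config):

from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer

def build_splitter(name=TEXT_SPLITTER_NAME, chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE):
    # hypothetical helper: pick a length function according to text_splitter_dict
    cfg = text_splitter_dict[name]
    if cfg.get("source") == "tiktoken":
        # measure chunk length with an OpenAI tiktoken encoding
        return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            encoding_name=cfg["tokenizer_name_or_path"],
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # otherwise measure chunk length with a HuggingFace tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_name_or_path"])
    return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )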

View File

@@ -32,6 +32,7 @@ FSCHAT_OPENAI_API = {
 # fastchat model_worker server
 # 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
 # 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
+# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里LLM_MODEL会自动添加
 FSCHAT_MODEL_WORKERS = {
     # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
     "default": {
@@ -39,7 +40,8 @@ FSCHAT_MODEL_WORKERS = {
         "port": 20002,
         "device": LLM_DEVICE,
         # False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ
-        "infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
+        # vllm对一些模型支持还不成熟暂时默认关闭
+        "infer_turbo": False,

         # model_worker多卡加载需要配置的参数
         # "gpus": None, # 使用的GPU以str的格式指定如"0,1"如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
@@ -97,24 +99,24 @@ FSCHAT_MODEL_WORKERS = {
     "zhipu-api": {  # 请为每个要运行的在线API设置不同的端口
         "port": 21001,
     },
-    "minimax-api": {
-        "port": 21002,
-    },
-    "xinghuo-api": {
-        "port": 21003,
-    },
-    "qianfan-api": {
-        "port": 21004,
-    },
-    "fangzhou-api": {
-        "port": 21005,
-    },
-    "qwen-api": {
-        "port": 21006,
-    },
-    "baichuan-api": {
-        "port": 21007,
-    },
+    # "minimax-api": {
+    #     "port": 21002,
+    # },
+    # "xinghuo-api": {
+    #     "port": 21003,
+    # },
+    # "qianfan-api": {
+    #     "port": 21004,
+    # },
+    # "fangzhou-api": {
+    #     "port": 21005,
+    # },
+    # "qwen-api": {
+    #     "port": 21006,
+    # },
+    # "baichuan-api": {
+    #     "port": 21007,
+    # },
 }

 # fastchat multi model worker server

View File

@@ -1,5 +1,6 @@
-langchain>=0.0.310
-fschat[model_worker]>=0.2.30
+langchain==0.0.313
+langchain-experimental==0.0.30
+fschat[model_worker]==0.2.30
 openai
 sentence_transformers
 transformers>=4.34

View File

@@ -1,5 +1,6 @@
-langchain>=0.0.310
-fschat[model_worker]>=0.2.30
+langchain==0.0.313
+langchain-experimental==0.0.30
+fschat[model_worker]==0.2.30
 openai
 sentence_transformers>=2.2.2
 transformers>=4.34

View File

@@ -17,8 +17,10 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from server.llm_api import (list_running_models, list_config_models,
+                            change_llm_model, stop_llm_model,
+                            get_model_config, list_search_engines)
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
 from typing import List

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -139,6 +141,11 @@ def create_app():
              summary="列出configs已配置的模型",
              )(list_config_models)

+    app.post("/llm_model/get_model_config",
+             tags=["LLM Model Management"],
+             summary="获取模型配置(合并后)",
+             )(get_model_config)
+
     app.post("/llm_model/stop",
              tags=["LLM Model Management"],
              summary="停止指定的LLM模型Model Worker)",
@@ -149,6 +156,17 @@ def create_app():
              summary="切换指定的LLM模型Model Worker)",
              )(change_llm_model)

+    # 服务器相关接口
+    app.post("/server/configs",
+             tags=["Server State"],
+             summary="获取服务器原始配置信息",
+             )(get_server_configs)
+
+    app.post("/server/list_search_engines",
+             tags=["Server State"],
+             summary="获取服务器支持的搜索引擎",
+             )(list_search_engines)
+
     return app

View File

@@ -33,7 +33,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                               max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),  # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
                               prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                              local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
                               ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
@@ -74,9 +73,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         source_documents = []
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
-            if local_doc_url:
-                url = "file://" + doc.metadata["source"]
-            else:
             parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
             url = f"{request.base_url}knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""

View File

@@ -1,7 +1,7 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
+from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
+                          get_httpx_client, get_model_worker_config)


 def list_running_models(
@@ -9,19 +9,21 @@ def list_running_models(
     placeholder: str = Body(None, description="该参数未使用,占位用"),
 ) -> BaseResponse:
     '''
-    从fastchat controller获取已加载模型列表
+    从fastchat controller获取已加载模型列表及其配置项
     '''
     try:
         controller_address = controller_address or fschat_controller_address()
         with get_httpx_client() as client:
             r = client.post(controller_address + "/list_models")
-            return BaseResponse(data=r.json()["models"])
+            models = r.json()["models"]
+            data = {m: get_model_worker_config(m) for m in models}
+            return BaseResponse(data=data)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
                      exc_info=e if log_verbose else None)
         return BaseResponse(
             code=500,
-            data=[],
+            data={},
             msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@@ -29,7 +31,38 @@ def list_config_models() -> BaseResponse:
     '''
     从本地获取configs中配置的模型列表
     '''
-    return BaseResponse(data=list_llm_models())
+    configs = list_config_llm_models()
+    # 删除ONLINE_MODEL配置中的敏感信息
+    for config in configs["online"].values():
+        del_keys = set(["worker_class"])
+        for k in config:
+            if "key" in k.lower() or "secret" in k.lower():
+                del_keys.add(k)
+        for k in del_keys:
+            config.pop(k, None)
+    return BaseResponse(data=configs)
+
+
+def get_model_config(
+        model_name: str = Body(description="配置中LLM模型的名称"),
+        placeholder: str = Body(description="占位用,无实际效果")
+) -> BaseResponse:
+    '''
+    获取LLM模型配置项合并后的
+    '''
+    config = get_model_worker_config(model_name=model_name)
+    # 删除ONLINE_MODEL配置中的敏感信息
+    del_keys = set(["worker_class"])
+    for k in config:
+        if "key" in k.lower() or "secret" in k.lower():
+            del_keys.add(k)
+    for k in del_keys:
+        config.pop(k, None)
+
+    return BaseResponse(data=config)


 def stop_llm_model(
@@ -79,3 +112,9 @@ def change_llm_model(
         return BaseResponse(
             code=500,
             msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
+
+
+def list_search_engines() -> BaseResponse:
+    from server.chat.search_engine_chat import SEARCH_ENGINES
+
+    return BaseResponse(data=list(SEARCH_ENGINES))
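
For clarity, the key-filtering rule that both list_config_models and get_model_config apply above can be summarized as a small stand-alone helper (illustration only, not project code):

def filter_sensitive_keys(config: dict) -> dict:
    # drop worker_class plus any key whose name contains "key" or "secret"
    del_keys = {"worker_class"}
    for k in config:
        if "key" in k.lower() or "secret" in k.lower():
            del_keys.add(k)
    return {k: v for k, v in config.items() if k not in del_keys}

# hypothetical sample values:
# filter_sensitive_keys({"api_key": "xxx", "provider": "ChatGLMWorker"})
# -> {"provider": "ChatGLMWorker"}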

View File

@@ -258,17 +258,18 @@ def list_embed_models() -> List[str]:
     return list(MODEL_PATH["embed_model"])


-def list_llm_models() -> Dict[str, List[str]]:
+def list_config_llm_models() -> Dict[str, Dict]:
     '''
-    get names of configured llm models with different types.
+    get configured llm models with different types.
     return [(model_name, config_type), ...]
     '''
     workers = list(FSCHAT_MODEL_WORKERS)
-    if "default" in workers:
-        workers.remove("default")
+    if LLM_MODEL not in workers:
+        workers.insert(0, LLM_MODEL)

     return {
-        "local": list(MODEL_PATH["llm_model"]),
-        "online": list(ONLINE_LLM_MODEL),
+        "local": MODEL_PATH["llm_model"],
+        "online": ONLINE_LLM_MODEL,
         "worker": workers,
     }
@@ -306,7 +307,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
     加载model worker的配置项
     优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
     '''
-    from configs.model_config import ONLINE_LLM_MODEL
+    from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
     from configs.server_config import FSCHAT_MODEL_WORKERS
     from server import model_workers
@@ -324,7 +325,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
             msg = f"在线模型 {model_name} 的provider没有正确配置"
             logger.error(f'{e.__class__.__name__}: {msg}',
                          exc_info=e if log_verbose else None)
-    config["model_path"] = get_model_path(model_name)
-    config["device"] = llm_device(config.get("device"))
+    # 本地模型
+    if model_name in MODEL_PATH["llm_model"]:
+        config["model_path"] = get_model_path(model_name)
+        config["device"] = llm_device(config.get("device"))
     return config
@@ -449,11 +451,11 @@ def set_httpx_config(
     # TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。

     # patch requests to use custom proxies instead of system settings
-    # def _get_proxies():
-    #     return {}
-    # import urllib.request
-    # urllib.request.getproxies = _get_proxies
+    def _get_proxies():
+        return proxies
+
+    import urllib.request
+    urllib.request.getproxies = _get_proxies

     # 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
@@ -557,3 +559,35 @@ def get_httpx_client(
         return httpx.AsyncClient(**kwargs)
     else:
         return httpx.Client(**kwargs)
+
+
+def get_server_configs() -> Dict:
+    '''
+    获取configs中的原始配置项供前端使用
+    '''
+    from configs.kb_config import (
+        DEFAULT_VS_TYPE,
+        CHUNK_SIZE,
+        OVERLAP_SIZE,
+        SCORE_THRESHOLD,
+        VECTOR_SEARCH_TOP_K,
+        SEARCH_ENGINE_TOP_K,
+        ZH_TITLE_ENHANCE,
+        text_splitter_dict,
+        TEXT_SPLITTER_NAME,
+    )
+    from configs.model_config import (
+        LLM_MODEL,
+        EMBEDDING_MODEL,
+        HISTORY_LEN,
+        TEMPERATURE,
+    )
+    from configs.prompt_config import PROMPT_TEMPLATES
+
+    _custom = {
+        "controller_address": fschat_controller_address(),
+        "openai_api_address": fschat_openai_api_address(),
+        "api_address": api_address(),
+    }
+
+    return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
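
A rough stand-alone illustration (not project code) of the locals()-merge idiom used by get_server_configs: every imported configuration name becomes a key of the returned dict, while names starting with an underscore are dropped in favour of the _custom entries.

def demo():
    FOO = 1
    BAR = 2
    _custom = {"baz": 3}
    # locals() here is {"FOO": 1, "BAR": 2, "_custom": {...}}; "_custom" is filtered out
    return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}

assert demo() == {"FOO": 1, "BAR": 2, "baz": 3}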

View File

@@ -14,7 +14,7 @@ from pprint import pprint

 api_base_url = api_address()
-api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
+api: ApiRequest = ApiRequest(api_base_url)

 kb = "kb_for_api_test"

View File

@@ -32,7 +32,7 @@ def get_running_models(api="/llm_model/list_models"):
     return []


-def test_running_models(api="/llm_model/list_models"):
+def test_running_models(api="/llm_model/list_running_models"):
    url = api_base_url + api
    r = requests.post(url)
    assert r.status_code == 200
@@ -48,7 +48,7 @@ def test_running_models(api="/llm_model/list_models"):
 # r = requests.post(url, json={""})


-def test_change_model(api="/llm_model/change"):
+def test_change_model(api="/llm_model/change_model"):
    url = api_base_url + api
    running_models = get_running_models()

View File

@@ -22,9 +22,10 @@ if __name__ == "__main__":
     )

     if not chat_box.chat_inited:
+        running_models = api.list_running_models()
         st.toast(
             f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
-            f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
+            f"当前运行中的模型`{running_models}`, 您可以开始提问了."
         )

     pages = {

View File

@ -2,11 +2,11 @@ import streamlit as st
from webui_pages.utils import * from webui_pages.utils import *
from streamlit_chatbox import * from streamlit_chatbox import *
from datetime import datetime from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os import os
from configs import LLM_MODEL, TEMPERATURE from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
from server.utils import get_model_worker_config
from typing import List, Dict from typing import List, Dict
chat_box = ChatBox( chat_box = ChatBox(
assistant_avatar=os.path.join( assistant_avatar=os.path.join(
"img", "img",
@ -15,9 +15,6 @@ chat_box = ChatBox(
) )
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' '''
返回消息历史 返回消息历史
@ -38,6 +35,26 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
return chat_box.filter_history(history_len=history_len, filter=filter) return chat_box.filter_history(history_len=history_len, filter=filter)
def get_default_llm_model(api: ApiRequest) -> (str, bool):
'''
从服务器上获取当前运行的LLM模型如果本机配置的LLM_MODEL属于本地模型且在其中则优先返回
返回类型为model_name, is_local_model
'''
running_models = api.list_running_models()
if not running_models:
return "", False
if LLM_MODEL in running_models:
return LLM_MODEL, True
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
if local_models:
return local_models[0], True
return running_models[0], False
def dialogue_page(api: ApiRequest): def dialogue_page(api: ApiRequest):
chat_box.init_session() chat_box.init_session()
@ -51,7 +68,6 @@ def dialogue_page(api: ApiRequest):
if cur_kb: if cur_kb:
text = f"{text} 当前知识库: `{cur_kb}`。" text = f"{text} 当前知识库: `{cur_kb}`。"
st.toast(text) st.toast(text)
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
dialogue_mode = st.selectbox("请选择对话模式:", dialogue_mode = st.selectbox("请选择对话模式:",
["LLM 对话", ["LLM 对话",
@ -65,7 +81,7 @@ def dialogue_page(api: ApiRequest):
) )
def on_llm_change(): def on_llm_change():
config = get_model_worker_config(llm_model) config = api.get_model_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型 if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model st.session_state["cur_llm_model"] = st.session_state.llm_model
@ -75,15 +91,20 @@ def dialogue_page(api: ApiRequest):
return f"{x} (Running)" return f"{x} (Running)"
return x return x
running_models = api.list_running_models() running_models = list(api.list_running_models())
available_models = [] available_models = []
config_models = api.list_config_models() config_models = api.list_config_models()
for models in config_models.values(): worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
for m in models: for m in worker_models:
if m not in running_models: if m not in running_models and m != "default":
available_models.append(m) available_models.append(m)
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型如GPT
if not v.get("provider") and k not in running_models:
print(k, v)
available_models.append(k)
llm_models = running_models + available_models llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL)) index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
llm_model = st.selectbox("选择LLM模型", llm_model = st.selectbox("选择LLM模型",
llm_models, llm_models,
index, index,
@ -92,7 +113,7 @@ def dialogue_page(api: ApiRequest):
key="llm_model", key="llm_model",
) )
if (st.session_state.get("prev_llm_model") != llm_model if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api") and not api.get_model_config(llm_model).get("online_api")
and llm_model not in running_models): and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model") prev_model = st.session_state.get("prev_llm_model")
@ -114,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "知识库问答": if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True): with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases(no_remote_api=True) kb_list = api.list_knowledge_bases()
selected_kb = st.selectbox( selected_kb = st.selectbox(
"请选择知识库:", "请选择知识库:",
kb_list, kb_list,
@ -126,7 +147,7 @@ def dialogue_page(api: ApiRequest):
# chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答": elif dialogue_mode == "搜索引擎问答":
search_engine_list = list(SEARCH_ENGINES.keys()) search_engine_list = api.list_search_engines()
with st.expander("搜索引擎配置", True): with st.expander("搜索引擎配置", True):
search_engine = st.selectbox( search_engine = st.selectbox(
label="请选择搜索引擎", label="请选择搜索引擎",

View File

@ -1,12 +1,14 @@
# 该文件包含webui通用工具可以被不同的webui使用 # 该文件封装了对api.py的请求可以被不同的webui使用
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
from typing import * from typing import *
from pathlib import Path from pathlib import Path
# 此处导入的配置为发起请求如WEBUI机器上的配置主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
from configs import ( from configs import (
EMBEDDING_MODEL, EMBEDDING_MODEL,
DEFAULT_VS_TYPE, DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL, LLM_MODEL,
HISTORY_LEN,
TEMPERATURE, TEMPERATURE,
SCORE_THRESHOLD, SCORE_THRESHOLD,
CHUNK_SIZE, CHUNK_SIZE,
@ -14,59 +16,44 @@ from configs import (
ZH_TITLE_ENHANCE, ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K, SEARCH_ENGINE_TOP_K,
FSCHAT_MODEL_WORKERS,
HTTPX_DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose, logger, log_verbose,
) )
import httpx import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib import contextlib
import json import json
import os import os
from io import BytesIO from io import BytesIO
from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from pprint import pprint from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_config() set_httpx_config()
class ApiRequest: class ApiRequest:
''' '''
api.py调用的封装,主要实现: api.py调用的封装同步模式,简化api调用方式
1. 简化api调用方式
2. 实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py
''' '''
def __init__( def __init__(
self, self,
base_url: str = api_address(), base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT, timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
): ):
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
self.no_remote_api = no_remote_api self._use_async = False
self._client = get_httpx_client() self._client = None
self._aclient = get_httpx_client(use_async=True)
if no_remote_api:
logger.warn("将来可能取消对no_remote_api的支持更新版本时请注意。")
def _parse_url(self, url: str) -> str: @property
if (not url.startswith("http") def client(self):
and self.base_url if self._client is None or self._client.is_closed:
): self._client = get_httpx_client(base_url=self.base_url,
part1 = self.base_url.strip(" /") use_async=self._use_async,
part2 = url.strip(" /") timeout=self.timeout)
return f"{part1}/{part2}" return self._client
else:
return url
def get( def get(
self, self,
@ -75,44 +62,19 @@ class ApiRequest:
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return self._client.stream("GET", url, params=params, **kwargs) return self.client.stream("GET", url, params=params, **kwargs)
else: else:
return self._client.get(url, params=params, **kwargs) return self.client.get(url, params=params, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when get {url}: {e}" msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
retry -= 1 retry -= 1
async def aget(
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
if stream:
return await self._aclient.stream("GET", url, params=params, **kwargs)
else:
return await self._aclient.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when aget {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def post( def post(
self, self,
url: str, url: str,
@ -121,45 +83,19 @@ class ApiRequest:
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return self._client.stream("POST", url, data=data, json=json, **kwargs) return self.client.stream("POST", url, data=data, json=json, **kwargs)
else: else:
return self._client.post(url, data=data, json=json, **kwargs) return self.client.post(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when post {url}: {e}" msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
retry -= 1 retry -= 1
async def apost(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
if stream:
return await self._client.stream("POST", url, data=data, json=json, **kwargs)
else:
return await self._client.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when apost {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def delete( def delete(
self, self,
url: str, url: str,
@ -168,65 +104,19 @@ class ApiRequest:
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
return self._client.stream("DELETE", url, data=data, json=json, **kwargs) return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
else: else:
return self._client.delete(url, data=data, json=json, **kwargs) return self.client.delete(url, data=data, json=json, **kwargs)
except Exception as e: except Exception as e:
msg = f"error when delete {url}: {e}" msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
retry -= 1 retry -= 1
async def adelete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
if stream:
return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return await self._aclient.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when adelete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
'''
将api.py中视图函数返回的StreamingResponse转化为同步生成器
'''
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
try:
for chunk in iter_over_async(response.body_iterator, loop):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
yield chunk
except Exception as e:
msg = f"error when run fastapi router: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
def _httpx_stream2generator( def _httpx_stream2generator(
self, self,
response: contextlib._GeneratorContextManager, response: contextlib._GeneratorContextManager,
@ -235,6 +125,39 @@ class ApiRequest:
''' '''
将httpx.stream返回的GeneratorContextManager转化为普通生成器 将httpx.stream返回的GeneratorContextManager转化为普通生成器
''' '''
async def ret_async(response, as_json):
try:
async with response as r:
async for chunk in r.aiter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg)
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({e}"
logger.error(msg)
yield {"code": 500, "msg": msg}
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
def ret_sync(response, as_json):
try: try:
with response as r: with response as r:
for chunk in r.iter_text(None): for chunk in r.iter_text(None):
@ -255,10 +178,9 @@ class ApiRequest:
except httpx.ConnectError as e: except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})" msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg) logger.error(msg)
logger.error(msg)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e: except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI')。({e}" msg = f"API通信超时请确认已启动FastChat与API服务详见Wiki '5. 启动 API 服务或 Web UI')。({e}"
logger.error(msg) logger.error(msg)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}
except Exception as e: except Exception as e:
@ -267,6 +189,54 @@ class ApiRequest:
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}
if self._use_async:
return ret_async(response, as_json)
else:
return ret_sync(response, as_json)
def _get_response_value(
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
):
'''
转换同步或异步请求返回的响应
`as_json`: 返回json
`value_func`: 用户可以自定义返回值该函数接受response或json
'''
def to_json(r):
try:
return r.json()
except Exception as e:
msg = "API未能返回正确的JSON。" + str(e)
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg}
if value_func is None:
value_func = (lambda r: r)
async def ret_async(response):
if as_json:
return value_func(to_json(await response))
else:
return value_func(await response)
if self._use_async:
return ret_async(response)
else:
if as_json:
return value_func(to_json(response))
else:
return value_func(response)
# 服务器信息
def get_server_configs(self, **kwargs):
response = self.post("/server/configs", **kwargs)
return self._get_response_value(response, lambda r: r.json())
# 对话相关操作 # 对话相关操作
def chat_fastchat( def chat_fastchat(
@ -275,15 +245,12 @@ class ApiRequest:
stream: bool = True, stream: bool = True,
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, # TODO:根据message内容自动计算max_tokens max_tokens: int = 1024,
no_remote_api: bool = None,
**kwargs: Any, **kwargs: Any,
): ):
''' '''
对应api.py/chat/fastchat接口 对应api.py/chat/fastchat接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
msg = OpenAiChatMsgIn(**{ msg = OpenAiChatMsgIn(**{
"messages": messages, "messages": messages,
"stream": stream, "stream": stream,
@ -293,11 +260,6 @@ class ApiRequest:
**kwargs, **kwargs,
}) })
if no_remote_api:
from server.chat.openai_chat import openai_chat
response = run_async(openai_chat(msg))
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True) data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:") print(f"received input message:")
pprint(data) pprint(data)
@ -318,14 +280,11 @@ class ApiRequest:
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, max_tokens: int = 1024,
prompt_name: str = "llm_chat", prompt_name: str = "llm_chat",
no_remote_api: bool = None, **kwargs,
): ):
''' '''
对应api.py/chat/chat接口 对应api.py/chat/chat接口 #TODO: 考虑是否返回json
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
@ -339,12 +298,7 @@ class ApiRequest:
print(f"received input message:") print(f"received input message:")
pprint(data) pprint(data)
if no_remote_api: response = self.post("/chat/chat", json=data, stream=True, **kwargs)
from server.chat.chat import chat
response = run_async(chat(**data))
return self._fastapi_stream2generator(response)
else:
response = self.post("/chat/chat", json=data, stream=True)
return self._httpx_stream2generator(response) return self._httpx_stream2generator(response)
def agent_chat( def agent_chat(
@ -355,14 +309,10 @@ class ApiRequest:
model: str = LLM_MODEL, model: str = LLM_MODEL,
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, max_tokens: int = 1024,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/chat/agent_chat 接口 对应api.py/chat/agent_chat 接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
@ -375,11 +325,6 @@ class ApiRequest:
print(f"received input message:") print(f"received input message:")
pprint(data) pprint(data)
if no_remote_api:
from server.chat.agent_chat import agent_chat
response = run_async(agent_chat(**data))
return self._fastapi_stream2generator(response)
else:
response = self.post("/chat/agent_chat", json=data, stream=True) response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response) return self._httpx_stream2generator(response)
@ -395,14 +340,10 @@ class ApiRequest:
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, max_tokens: int = 1024,
prompt_name: str = "knowledge_base_chat", prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
): ):
''' '''
对应api.py/chat/knowledge_base_chat接口 对应api.py/chat/knowledge_base_chat接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"query": query, "query": query,
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
@ -413,18 +354,12 @@ class ApiRequest:
"model_name": model, "model_name": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"local_doc_url": no_remote_api,
"prompt_name": prompt_name, "prompt_name": prompt_name,
} }
print(f"received input message:") print(f"received input message:")
pprint(data) pprint(data)
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = run_async(knowledge_base_chat(**data))
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post( response = self.post(
"/chat/knowledge_base_chat", "/chat/knowledge_base_chat",
json=data, json=data,
@ -443,14 +378,10 @@ class ApiRequest:
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = 1024, max_tokens: int = 1024,
prompt_name: str = "knowledge_base_chat", prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
): ):
''' '''
对应api.py/chat/search_engine_chat接口 对应api.py/chat/search_engine_chat接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"query": query, "query": query,
"search_engine_name": search_engine_name, "search_engine_name": search_engine_name,
@ -466,11 +397,6 @@ class ApiRequest:
print(f"received input message:") print(f"received input message:")
pprint(data) pprint(data)
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = run_async(search_engine_chat(**data))
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post( response = self.post(
"/chat/search_engine_chat", "/chat/search_engine_chat",
json=data, json=data,
@ -480,116 +406,65 @@ class ApiRequest:
# 知识库相关操作 # 知识库相关操作
def _check_httpx_json_response(
self,
response: httpx.Response,
errorMsg: str = f"无法连接API服务器请确认 api.py 已正常启动。",
) -> Dict:
'''
check whether httpx returns correct data with normal Response.
error in api with streaming support was checked in _httpx_stream2enerator
'''
try:
return response.json()
except Exception as e:
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg}
def list_knowledge_bases( def list_knowledge_bases(
self, self,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/list_knowledge_bases接口 对应api.py/knowledge_base/list_knowledge_bases接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_api import list_kbs
response = list_kbs()
return response.data
else:
response = self.get("/knowledge_base/list_knowledge_bases") response = self.get("/knowledge_base/list_knowledge_bases")
data = self._check_httpx_json_response(response) return self._get_response_value(response,
return data.get("data", []) as_json=True,
value_func=lambda r: r.get("data", []))
def create_knowledge_base( def create_knowledge_base(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
vector_store_type: str = "faiss", vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/create_knowledge_base接口 对应api.py/knowledge_base/create_knowledge_base接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"vector_store_type": vector_store_type, "vector_store_type": vector_store_type,
"embed_model": embed_model, "embed_model": embed_model,
} }
if no_remote_api:
from server.knowledge_base.kb_api import create_kb
response = create_kb(**data)
return response.dict()
else:
response = self.post( response = self.post(
"/knowledge_base/create_knowledge_base", "/knowledge_base/create_knowledge_base",
json=data, json=data,
) )
return self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
def delete_knowledge_base( def delete_knowledge_base(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/delete_knowledge_base接口 对应api.py/knowledge_base/delete_knowledge_base接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_api import delete_kb
response = delete_kb(knowledge_base_name)
return response.dict()
else:
response = self.post( response = self.post(
"/knowledge_base/delete_knowledge_base", "/knowledge_base/delete_knowledge_base",
json=f"{knowledge_base_name}", json=f"{knowledge_base_name}",
) )
return self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
def list_kb_docs( def list_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/list_files接口 对应api.py/knowledge_base/list_files接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_doc_api import list_files
response = list_files(knowledge_base_name)
return response.data
else:
response = self.get( response = self.get(
"/knowledge_base/list_files", "/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name} params={"knowledge_base_name": knowledge_base_name}
) )
data = self._check_httpx_json_response(response) return self._get_response_value(response,
return data.get("data", []) as_json=True,
value_func=lambda r: r.get("data", []))
def search_kb_docs( def search_kb_docs(
self, self,
@ -597,14 +472,10 @@ class ApiRequest:
knowledge_base_name: str, knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD, score_threshold: int = SCORE_THRESHOLD,
no_remote_api: bool = None,
) -> List: ) -> List:
''' '''
对应api.py/knowledge_base/search_docs接口 对应api.py/knowledge_base/search_docs接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"query": query, "query": query,
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
@ -612,16 +483,11 @@ class ApiRequest:
"score_threshold": score_threshold, "score_threshold": score_threshold,
} }
if no_remote_api:
from server.knowledge_base.kb_doc_api import search_docs
return search_docs(**data)
else:
response = self.post( response = self.post(
"/knowledge_base/search_docs", "/knowledge_base/search_docs",
json=data, json=data,
) )
data = self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
return data
def upload_kb_docs( def upload_kb_docs(
self, self,
@ -634,14 +500,10 @@ class ApiRequest:
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {}, docs: Dict = {},
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/upload_docs接口 对应api.py/knowledge_base/upload_docs接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
def convert_file(file, filename=None): def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes if isinstance(file, bytes): # raw bytes
file = BytesIO(file) file = BytesIO(file)
@ -664,21 +526,6 @@ class ApiRequest:
"not_refresh_vs_cache": not_refresh_vs_cache, "not_refresh_vs_cache": not_refresh_vs_cache,
} }
if no_remote_api:
from server.knowledge_base.kb_doc_api import upload_docs
from fastapi import UploadFile
from tempfile import SpooledTemporaryFile
upload_files = []
for filename, file in files:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
upload_files.append(UploadFile(file=temp_file, filename=filename))
response = upload_docs(upload_files, **data)
return response.dict()
else:
if isinstance(data["docs"], dict): if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False) data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post( response = self.post(
@ -686,7 +533,7 @@ class ApiRequest:
data=data, data=data,
files=[("files", (filename, file)) for filename, file in files], files=[("files", (filename, file)) for filename, file in files],
) )
return self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
def delete_kb_docs( def delete_kb_docs(
self, self,
@ -694,14 +541,10 @@ class ApiRequest:
file_names: List[str], file_names: List[str],
delete_content: bool = False, delete_content: bool = False,
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/delete_docs接口 对应api.py/knowledge_base/delete_docs接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"file_names": file_names, "file_names": file_names,
@ -709,16 +552,11 @@ class ApiRequest:
"not_refresh_vs_cache": not_refresh_vs_cache, "not_refresh_vs_cache": not_refresh_vs_cache,
} }
if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_docs
response = delete_docs(**data)
return response.dict()
else:
response = self.post( response = self.post(
"/knowledge_base/delete_docs", "/knowledge_base/delete_docs",
json=data, json=data,
) )
return self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
def update_kb_docs( def update_kb_docs(
self, self,
@ -730,14 +568,10 @@ class ApiRequest:
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {}, docs: Dict = {},
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/update_docs接口 对应api.py/knowledge_base/update_docs接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"file_names": file_names, "file_names": file_names,
@ -748,18 +582,15 @@ class ApiRequest:
"docs": docs, "docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache, "not_refresh_vs_cache": not_refresh_vs_cache,
} }
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_docs
response = update_docs(**data)
return response.dict()
else:
if isinstance(data["docs"], dict): if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False) data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post( response = self.post(
"/knowledge_base/update_docs", "/knowledge_base/update_docs",
json=data, json=data,
) )
return self._check_httpx_json_response(response) return self._get_response_value(response, as_json=True)
def recreate_vector_store( def recreate_vector_store(
self, self,
@ -770,14 +601,10 @@ class ApiRequest:
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
no_remote_api: bool = None,
): ):
''' '''
对应api.py/knowledge_base/recreate_vector_store接口 对应api.py/knowledge_base/recreate_vector_store接口
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"allow_empty_kb": allow_empty_kb, "allow_empty_kb": allow_empty_kb,
@ -788,11 +615,6 @@ class ApiRequest:
"zh_title_enhance": zh_title_enhance, "zh_title_enhance": zh_title_enhance,
} }
if no_remote_api:
from server.knowledge_base.kb_doc_api import recreate_vector_store
response = recreate_vector_store(**data)
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post( response = self.post(
"/knowledge_base/recreate_vector_store", "/knowledge_base/recreate_vector_store",
json=data, json=data,
@ -805,88 +627,89 @@ class ApiRequest:
def list_running_models( def list_running_models(
self, self,
controller_address: str = None, controller_address: str = None,
no_remote_api: bool = None,
): ):
''' '''
获取Fastchat中正运行的模型列表 获取Fastchat中正运行的模型列表
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"controller_address": controller_address, "controller_address": controller_address,
} }
if no_remote_api:
from server.llm_api import list_running_models response = self.post(
return list_running_models(**data).data
else:
r = self.post(
"/llm_model/list_running_models", "/llm_model/list_running_models",
json=data, json=data,
) )
return r.json().get("data", []) return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]: def list_config_models(self) -> Dict[str, List[str]]:
''' '''
获取configs中配置的模型列表返回形式为{"type": [model_name1, model_name2, ...], ...} 获取服务器configs中配置的模型列表返回形式为{"type": [model_name1, model_name2, ...], ...}
如果no_remote_api=True, 从运行ApiRequest的机器上获取否则从运行api.py的机器上获取
''' '''
if no_remote_api is None: response = self.post(
no_remote_api = self.no_remote_api
if no_remote_api:
from server.llm_api import list_config_models
return list_config_models().data
else:
r = self.post(
"/llm_model/list_config_models", "/llm_model/list_config_models",
) )
return r.json().get("data", {}) return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
def get_model_config(
self,
model_name: str,
) -> Dict:
'''
获取服务器上模型配置
'''
data={
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
def list_search_engines(self) -> List[str]:
'''
获取服务器支持的搜索引擎
'''
response = self.post(
"/server/list_search_engines",
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
def stop_llm_model( def stop_llm_model(
self, self,
model_name: str, model_name: str,
controller_address: str = None, controller_address: str = None,
no_remote_api: bool = None,
): ):
''' '''
停止某个LLM模型 停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉 注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = { data = {
"model_name": model_name, "model_name": model_name,
"controller_address": controller_address, "controller_address": controller_address,
} }
if no_remote_api: response = self.post(
from server.llm_api import stop_llm_model
return stop_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/stop", "/llm_model/stop",
json=data, json=data,
) )
return r.json() return self._get_response_value(response, as_json=True)
def change_llm_model( def change_llm_model(
self, self,
model_name: str, model_name: str,
new_model_name: str, new_model_name: str,
controller_address: str = None, controller_address: str = None,
no_remote_api: bool = None,
): ):
''' '''
向fastchat controller请求切换LLM模型 向fastchat controller请求切换LLM模型
''' '''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if not model_name or not new_model_name: if not model_name or not new_model_name:
return return {
"code": 500,
"msg": f"未指定模型名称"
}
def ret_sync():
running_models = self.list_running_models() running_models = self.list_running_models()
if new_model_name == model_name or new_model_name in running_models: if new_model_name == model_name or new_model_name in running_models:
return { return {
@ -901,7 +724,7 @@ class ApiRequest:
} }
config_models = self.list_config_models() config_models = self.list_config_models()
if new_model_name not in config_models.get("local", []): if new_model_name not in config_models.get("local", {}):
return { return {
"code": 500, "code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
@ -913,16 +736,55 @@ class ApiRequest:
"controller_address": controller_address, "controller_address": controller_address,
} }
if no_remote_api: response = self.post(
from server.llm_api import change_llm_model
return change_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/change", "/llm_model/change",
json=data, json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
) )
return r.json() return self._get_response_value(response, as_json=True)
async def ret_async():
running_models = await self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "无需切换"
}
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = await self.list_config_models()
if new_model_name not in config_models.get("local", {}):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
response = self.post(
"/llm_model/change",
json=data,
)
return self._get_response_value(response, as_json=True)
if self._use_async:
return ret_async()
else:
return ret_sync()
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
super().__init__(base_url, timeout)
self._use_async = True
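
To illustrate the new synchronous/asynchronous split, a minimal usage sketch (assumes api.py is running): the method names are identical in both classes, but AsyncApiRequest returns awaitables where ApiRequest returns values.

import asyncio

api = ApiRequest()
print(api.list_running_models())             # synchronous call

async def main():
    aapi = AsyncApiRequest()
    print(await aapi.list_running_models())  # asynchronous call

# asyncio.run(main())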
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
@ -950,7 +812,8 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
if __name__ == "__main__": if __name__ == "__main__":
api = ApiRequest(no_remote_api=True) api = ApiRequest()
aapi = AsyncApiRequest()
# print(api.chat_fastchat( # print(api.chat_fastchat(
# messages=[{"role": "user", "content": "hello"}] # messages=[{"role": "user", "content": "hello"}]