mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-08 07:53:29 +08:00
实现Api和WEBUI的前后端分离 (#1772)
* update ApiRequest: 删除no_remote_api本地调用模式;支持同步/异步调用 * 实现API和WEBUI的分离: - API运行服务器上的配置通过/llm_model/get_model_config、/server/configs接口提供,WEBUI运行机器上的配置项仅作为代码内部默认值使用 - 服务器可用的搜索引擎通过/server/list_search_engines提供 - WEBUI可选LLM列表中只列出在FSCHAT_MODEL_WORKERS中配置的模型 - 修改WEBUI中默认LLM_MODEL获取方式,改为从api端读取 - 删除knowledge_base_chat中`local_doc_url`参数 其它修改: - 删除多余的kb_config.py.exmaple(名称错误) - server_config中默认关闭vllm - server_config中默认注释除智谱AI之外的在线模型 - 修改requests从系统获取的代理,避免model worker注册错误 * 修正: - api.list_config_models返回模型原始配置 - api.list_config_models和api.get_model_config中过滤online api模型的敏感信息 - 将GPT等直接访问的模型列入WEBUI可选模型列表 其它: - 指定langchain==0.3.313, fschat==0.2.30, langchain-experimental==0.0.30
This commit is contained in:
parent
94977c7ab1
commit
9ce328fea9
@ -1,99 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
# 默认向量库类型。可选:faiss, milvus, pg.
|
|
||||||
DEFAULT_VS_TYPE = "faiss"
|
|
||||||
|
|
||||||
# 缓存向量库数量(针对FAISS)
|
|
||||||
CACHED_VS_NUM = 1
|
|
||||||
|
|
||||||
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
|
|
||||||
CHUNK_SIZE = 250
|
|
||||||
|
|
||||||
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
|
|
||||||
OVERLAP_SIZE = 50
|
|
||||||
|
|
||||||
# 知识库匹配向量数量
|
|
||||||
VECTOR_SEARCH_TOP_K = 3
|
|
||||||
|
|
||||||
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
|
|
||||||
SCORE_THRESHOLD = 1
|
|
||||||
|
|
||||||
# 搜索引擎匹配结题数量
|
|
||||||
SEARCH_ENGINE_TOP_K = 3
|
|
||||||
|
|
||||||
|
|
||||||
# Bing 搜索必备变量
|
|
||||||
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
|
|
||||||
# 具体申请方式请见
|
|
||||||
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
|
|
||||||
# 使用python创建bing api 搜索实例详见:
|
|
||||||
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
|
|
||||||
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
|
|
||||||
# 注意不是bing Webmaster Tools的api key,
|
|
||||||
|
|
||||||
# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
|
|
||||||
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
|
|
||||||
BING_SUBSCRIPTION_KEY = ""
|
|
||||||
|
|
||||||
# 是否开启中文标题加强,以及标题增强的相关配置
|
|
||||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
|
||||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
|
||||||
ZH_TITLE_ENHANCE = False
|
|
||||||
|
|
||||||
|
|
||||||
# 通常情况下不需要更改以下内容
|
|
||||||
|
|
||||||
# 知识库默认存储路径
|
|
||||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
|
|
||||||
if not os.path.exists(KB_ROOT_PATH):
|
|
||||||
os.mkdir(KB_ROOT_PATH)
|
|
||||||
|
|
||||||
# 数据库默认存储路径。
|
|
||||||
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
|
|
||||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
|
||||||
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
|
|
||||||
|
|
||||||
# 可选向量库类型及对应配置
|
|
||||||
kbs_config = {
|
|
||||||
"faiss": {
|
|
||||||
},
|
|
||||||
"milvus": {
|
|
||||||
"host": "127.0.0.1",
|
|
||||||
"port": "19530",
|
|
||||||
"user": "",
|
|
||||||
"password": "",
|
|
||||||
"secure": False,
|
|
||||||
},
|
|
||||||
"pg": {
|
|
||||||
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
|
|
||||||
text_splitter_dict = {
|
|
||||||
"ChineseRecursiveTextSplitter": {
|
|
||||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
|
||||||
"tokenizer_name_or_path": "gpt2",
|
|
||||||
},
|
|
||||||
"SpacyTextSplitter": {
|
|
||||||
"source": "huggingface",
|
|
||||||
"tokenizer_name_or_path": "",
|
|
||||||
},
|
|
||||||
"RecursiveCharacterTextSplitter": {
|
|
||||||
"source": "tiktoken",
|
|
||||||
"tokenizer_name_or_path": "cl100k_base",
|
|
||||||
},
|
|
||||||
"MarkdownHeaderTextSplitter": {
|
|
||||||
"headers_to_split_on":
|
|
||||||
[
|
|
||||||
("#", "head1"),
|
|
||||||
("##", "head2"),
|
|
||||||
("###", "head3"),
|
|
||||||
("####", "head4"),
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# TEXT_SPLITTER 名称
|
|
||||||
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
|
|
||||||
@ -32,6 +32,7 @@ FSCHAT_OPENAI_API = {
|
|||||||
# fastchat model_worker server
|
# fastchat model_worker server
|
||||||
# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
|
# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
|
||||||
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||||
|
# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里(LLM_MODEL会自动添加)
|
||||||
FSCHAT_MODEL_WORKERS = {
|
FSCHAT_MODEL_WORKERS = {
|
||||||
# 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
|
# 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
|
||||||
"default": {
|
"default": {
|
||||||
@ -39,7 +40,8 @@ FSCHAT_MODEL_WORKERS = {
|
|||||||
"port": 20002,
|
"port": 20002,
|
||||||
"device": LLM_DEVICE,
|
"device": LLM_DEVICE,
|
||||||
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题,参见doc/FAQ
|
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题,参见doc/FAQ
|
||||||
"infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
|
# vllm对一些模型支持还不成熟,暂时默认关闭
|
||||||
|
"infer_turbo": False,
|
||||||
|
|
||||||
# model_worker多卡加载需要配置的参数
|
# model_worker多卡加载需要配置的参数
|
||||||
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1",如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
|
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1",如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
|
||||||
@ -97,24 +99,24 @@ FSCHAT_MODEL_WORKERS = {
|
|||||||
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口
|
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口
|
||||||
"port": 21001,
|
"port": 21001,
|
||||||
},
|
},
|
||||||
"minimax-api": {
|
# "minimax-api": {
|
||||||
"port": 21002,
|
# "port": 21002,
|
||||||
},
|
# },
|
||||||
"xinghuo-api": {
|
# "xinghuo-api": {
|
||||||
"port": 21003,
|
# "port": 21003,
|
||||||
},
|
# },
|
||||||
"qianfan-api": {
|
# "qianfan-api": {
|
||||||
"port": 21004,
|
# "port": 21004,
|
||||||
},
|
# },
|
||||||
"fangzhou-api": {
|
# "fangzhou-api": {
|
||||||
"port": 21005,
|
# "port": 21005,
|
||||||
},
|
# },
|
||||||
"qwen-api": {
|
# "qwen-api": {
|
||||||
"port": 21006,
|
# "port": 21006,
|
||||||
},
|
# },
|
||||||
"baichuan-api": {
|
# "baichuan-api": {
|
||||||
"port": 21007,
|
# "port": 21007,
|
||||||
},
|
# },
|
||||||
}
|
}
|
||||||
|
|
||||||
# fastchat multi model worker server
|
# fastchat multi model worker server
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
langchain>=0.0.310
|
langchain==0.0.313
|
||||||
fschat[model_worker]>=0.2.30
|
langchain-experimental==0.0.30
|
||||||
|
fschat[model_worker]==0.2.30
|
||||||
openai
|
openai
|
||||||
sentence_transformers
|
sentence_transformers
|
||||||
transformers>=4.34
|
transformers>=4.34
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
langchain>=0.0.310
|
langchain==0.0.313
|
||||||
fschat[model_worker]>=0.2.30
|
langchain-experimental==0.0.30
|
||||||
|
fschat[model_worker]==0.2.30
|
||||||
openai
|
openai
|
||||||
sentence_transformers>=2.2.2
|
sentence_transformers>=2.2.2
|
||||||
transformers>=4.34
|
transformers>=4.34
|
||||||
|
|||||||
@ -17,8 +17,10 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
|||||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||||
update_docs, download_doc, recreate_vector_store,
|
update_docs, download_doc, recreate_vector_store,
|
||||||
search_docs, DocumentWithScore)
|
search_docs, DocumentWithScore)
|
||||||
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
|
from server.llm_api import (list_running_models, list_config_models,
|
||||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
change_llm_model, stop_llm_model,
|
||||||
|
get_model_config, list_search_engines)
|
||||||
|
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
@ -139,6 +141,11 @@ def create_app():
|
|||||||
summary="列出configs已配置的模型",
|
summary="列出configs已配置的模型",
|
||||||
)(list_config_models)
|
)(list_config_models)
|
||||||
|
|
||||||
|
app.post("/llm_model/get_model_config",
|
||||||
|
tags=["LLM Model Management"],
|
||||||
|
summary="获取模型配置(合并后)",
|
||||||
|
)(get_model_config)
|
||||||
|
|
||||||
app.post("/llm_model/stop",
|
app.post("/llm_model/stop",
|
||||||
tags=["LLM Model Management"],
|
tags=["LLM Model Management"],
|
||||||
summary="停止指定的LLM模型(Model Worker)",
|
summary="停止指定的LLM模型(Model Worker)",
|
||||||
@ -149,6 +156,17 @@ def create_app():
|
|||||||
summary="切换指定的LLM模型(Model Worker)",
|
summary="切换指定的LLM模型(Model Worker)",
|
||||||
)(change_llm_model)
|
)(change_llm_model)
|
||||||
|
|
||||||
|
# 服务器相关接口
|
||||||
|
app.post("/server/configs",
|
||||||
|
tags=["Server State"],
|
||||||
|
summary="获取服务器原始配置信息",
|
||||||
|
)(get_server_configs)
|
||||||
|
|
||||||
|
app.post("/server/list_search_engines",
|
||||||
|
tags=["Server State"],
|
||||||
|
summary="获取服务器支持的搜索引擎",
|
||||||
|
)(list_search_engines)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||||
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
||||||
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
|
||||||
request: Request = None,
|
request: Request = None,
|
||||||
):
|
):
|
||||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
@ -74,9 +73,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||||||
source_documents = []
|
source_documents = []
|
||||||
for inum, doc in enumerate(docs):
|
for inum, doc in enumerate(docs):
|
||||||
filename = os.path.split(doc.metadata["source"])[-1]
|
filename = os.path.split(doc.metadata["source"])[-1]
|
||||||
if local_doc_url:
|
|
||||||
url = "file://" + doc.metadata["source"]
|
|
||||||
else:
|
|
||||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||||
url = f"{request.base_url}knowledge_base/download_doc?" + parameters
|
url = f"{request.base_url}knowledge_base/download_doc?" + parameters
|
||||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
||||||
from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
|
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
|
||||||
|
get_httpx_client, get_model_worker_config)
|
||||||
|
|
||||||
|
|
||||||
def list_running_models(
|
def list_running_models(
|
||||||
@ -9,19 +9,21 @@ def list_running_models(
|
|||||||
placeholder: str = Body(None, description="该参数未使用,占位用"),
|
placeholder: str = Body(None, description="该参数未使用,占位用"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
'''
|
'''
|
||||||
从fastchat controller获取已加载模型列表
|
从fastchat controller获取已加载模型列表及其配置项
|
||||||
'''
|
'''
|
||||||
try:
|
try:
|
||||||
controller_address = controller_address or fschat_controller_address()
|
controller_address = controller_address or fschat_controller_address()
|
||||||
with get_httpx_client() as client:
|
with get_httpx_client() as client:
|
||||||
r = client.post(controller_address + "/list_models")
|
r = client.post(controller_address + "/list_models")
|
||||||
return BaseResponse(data=r.json()["models"])
|
models = r.json()["models"]
|
||||||
|
data = {m: get_model_worker_config(m) for m in models}
|
||||||
|
return BaseResponse(data=data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'{e.__class__.__name__}: {e}',
|
logger.error(f'{e.__class__.__name__}: {e}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
return BaseResponse(
|
return BaseResponse(
|
||||||
code=500,
|
code=500,
|
||||||
data=[],
|
data={},
|
||||||
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
||||||
|
|
||||||
|
|
||||||
@ -29,7 +31,38 @@ def list_config_models() -> BaseResponse:
|
|||||||
'''
|
'''
|
||||||
从本地获取configs中配置的模型列表
|
从本地获取configs中配置的模型列表
|
||||||
'''
|
'''
|
||||||
return BaseResponse(data=list_llm_models())
|
configs = list_config_llm_models()
|
||||||
|
|
||||||
|
# 删除ONLINE_MODEL配置中的敏感信息
|
||||||
|
for config in configs["online"].values():
|
||||||
|
del_keys = set(["worker_class"])
|
||||||
|
for k in config:
|
||||||
|
if "key" in k.lower() or "secret" in k.lower():
|
||||||
|
del_keys.add(k)
|
||||||
|
for k in del_keys:
|
||||||
|
config.pop(k, None)
|
||||||
|
|
||||||
|
return BaseResponse(data=configs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_config(
|
||||||
|
model_name: str = Body(description="配置中LLM模型的名称"),
|
||||||
|
placeholder: str = Body(description="占位用,无实际效果")
|
||||||
|
) -> BaseResponse:
|
||||||
|
'''
|
||||||
|
获取LLM模型配置项(合并后的)
|
||||||
|
'''
|
||||||
|
config = get_model_worker_config(model_name=model_name)
|
||||||
|
|
||||||
|
# 删除ONLINE_MODEL配置中的敏感信息
|
||||||
|
del_keys = set(["worker_class"])
|
||||||
|
for k in config:
|
||||||
|
if "key" in k.lower() or "secret" in k.lower():
|
||||||
|
del_keys.add(k)
|
||||||
|
for k in del_keys:
|
||||||
|
config.pop(k, None)
|
||||||
|
|
||||||
|
return BaseResponse(data=config)
|
||||||
|
|
||||||
|
|
||||||
def stop_llm_model(
|
def stop_llm_model(
|
||||||
@ -79,3 +112,9 @@ def change_llm_model(
|
|||||||
return BaseResponse(
|
return BaseResponse(
|
||||||
code=500,
|
code=500,
|
||||||
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def list_search_engines() -> BaseResponse:
|
||||||
|
from server.chat.search_engine_chat import SEARCH_ENGINES
|
||||||
|
|
||||||
|
return BaseResponse(data=list(SEARCH_ENGINES))
|
||||||
|
|||||||
@ -258,17 +258,18 @@ def list_embed_models() -> List[str]:
|
|||||||
return list(MODEL_PATH["embed_model"])
|
return list(MODEL_PATH["embed_model"])
|
||||||
|
|
||||||
|
|
||||||
def list_llm_models() -> Dict[str, List[str]]:
|
def list_config_llm_models() -> Dict[str, Dict]:
|
||||||
'''
|
'''
|
||||||
get names of configured llm models with different types.
|
get configured llm models with different types.
|
||||||
return [(model_name, config_type), ...]
|
return [(model_name, config_type), ...]
|
||||||
'''
|
'''
|
||||||
workers = list(FSCHAT_MODEL_WORKERS)
|
workers = list(FSCHAT_MODEL_WORKERS)
|
||||||
if "default" in workers:
|
if LLM_MODEL not in workers:
|
||||||
workers.remove("default")
|
workers.insert(0, LLM_MODEL)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"local": list(MODEL_PATH["llm_model"]),
|
"local": MODEL_PATH["llm_model"],
|
||||||
"online": list(ONLINE_LLM_MODEL),
|
"online": ONLINE_LLM_MODEL,
|
||||||
"worker": workers,
|
"worker": workers,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,7 +307,7 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
|||||||
加载model worker的配置项。
|
加载model worker的配置项。
|
||||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||||
'''
|
'''
|
||||||
from configs.model_config import ONLINE_LLM_MODEL
|
from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
|
||||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||||
from server import model_workers
|
from server import model_workers
|
||||||
|
|
||||||
@ -324,7 +325,8 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
|||||||
msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
|
# 本地模型
|
||||||
|
if model_name in MODEL_PATH["llm_model"]:
|
||||||
config["model_path"] = get_model_path(model_name)
|
config["model_path"] = get_model_path(model_name)
|
||||||
config["device"] = llm_device(config.get("device"))
|
config["device"] = llm_device(config.get("device"))
|
||||||
return config
|
return config
|
||||||
@ -449,11 +451,11 @@ def set_httpx_config(
|
|||||||
|
|
||||||
# TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
|
# TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
|
||||||
# patch requests to use custom proxies instead of system settings
|
# patch requests to use custom proxies instead of system settings
|
||||||
# def _get_proxies():
|
def _get_proxies():
|
||||||
# return {}
|
return proxies
|
||||||
|
|
||||||
# import urllib.request
|
import urllib.request
|
||||||
# urllib.request.getproxies = _get_proxies
|
urllib.request.getproxies = _get_proxies
|
||||||
|
|
||||||
|
|
||||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||||
@ -557,3 +559,35 @@ def get_httpx_client(
|
|||||||
return httpx.AsyncClient(**kwargs)
|
return httpx.AsyncClient(**kwargs)
|
||||||
else:
|
else:
|
||||||
return httpx.Client(**kwargs)
|
return httpx.Client(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_configs() -> Dict:
|
||||||
|
'''
|
||||||
|
获取configs中的原始配置项,供前端使用
|
||||||
|
'''
|
||||||
|
from configs.kb_config import (
|
||||||
|
DEFAULT_VS_TYPE,
|
||||||
|
CHUNK_SIZE,
|
||||||
|
OVERLAP_SIZE,
|
||||||
|
SCORE_THRESHOLD,
|
||||||
|
VECTOR_SEARCH_TOP_K,
|
||||||
|
SEARCH_ENGINE_TOP_K,
|
||||||
|
ZH_TITLE_ENHANCE,
|
||||||
|
text_splitter_dict,
|
||||||
|
TEXT_SPLITTER_NAME,
|
||||||
|
)
|
||||||
|
from configs.model_config import (
|
||||||
|
LLM_MODEL,
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
HISTORY_LEN,
|
||||||
|
TEMPERATURE,
|
||||||
|
)
|
||||||
|
from configs.prompt_config import PROMPT_TEMPLATES
|
||||||
|
|
||||||
|
_custom = {
|
||||||
|
"controller_address": fschat_controller_address(),
|
||||||
|
"openai_api_address": fschat_openai_api_address(),
|
||||||
|
"api_address": api_address(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from pprint import pprint
|
|||||||
|
|
||||||
|
|
||||||
api_base_url = api_address()
|
api_base_url = api_address()
|
||||||
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
|
api: ApiRequest = ApiRequest(api_base_url)
|
||||||
|
|
||||||
|
|
||||||
kb = "kb_for_api_test"
|
kb = "kb_for_api_test"
|
||||||
|
|||||||
@ -32,7 +32,7 @@ def get_running_models(api="/llm_model/list_models"):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def test_running_models(api="/llm_model/list_models"):
|
def test_running_models(api="/llm_model/list_running_models"):
|
||||||
url = api_base_url + api
|
url = api_base_url + api
|
||||||
r = requests.post(url)
|
r = requests.post(url)
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
@ -48,7 +48,7 @@ def test_running_models(api="/llm_model/list_models"):
|
|||||||
# r = requests.post(url, json={""})
|
# r = requests.post(url, json={""})
|
||||||
|
|
||||||
|
|
||||||
def test_change_model(api="/llm_model/change"):
|
def test_change_model(api="/llm_model/change_model"):
|
||||||
url = api_base_url + api
|
url = api_base_url + api
|
||||||
|
|
||||||
running_models = get_running_models()
|
running_models = get_running_models()
|
||||||
|
|||||||
3
webui.py
3
webui.py
@ -22,9 +22,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not chat_box.chat_inited:
|
if not chat_box.chat_inited:
|
||||||
|
running_models = api.list_running_models()
|
||||||
st.toast(
|
st.toast(
|
||||||
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
|
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
|
||||||
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
|
f"当前运行中的模型`{running_models}`, 您可以开始提问了."
|
||||||
)
|
)
|
||||||
|
|
||||||
pages = {
|
pages = {
|
||||||
|
|||||||
@ -2,11 +2,11 @@ import streamlit as st
|
|||||||
from webui_pages.utils import *
|
from webui_pages.utils import *
|
||||||
from streamlit_chatbox import *
|
from streamlit_chatbox import *
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from server.chat.search_engine_chat import SEARCH_ENGINES
|
|
||||||
import os
|
import os
|
||||||
from configs import LLM_MODEL, TEMPERATURE
|
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
|
||||||
from server.utils import get_model_worker_config
|
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
|
|
||||||
chat_box = ChatBox(
|
chat_box = ChatBox(
|
||||||
assistant_avatar=os.path.join(
|
assistant_avatar=os.path.join(
|
||||||
"img",
|
"img",
|
||||||
@ -15,9 +15,6 @@ chat_box = ChatBox(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||||
'''
|
'''
|
||||||
返回消息历史。
|
返回消息历史。
|
||||||
@ -38,6 +35,26 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
|||||||
return chat_box.filter_history(history_len=history_len, filter=filter)
|
return chat_box.filter_history(history_len=history_len, filter=filter)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_llm_model(api: ApiRequest) -> (str, bool):
|
||||||
|
'''
|
||||||
|
从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
|
||||||
|
返回类型为(model_name, is_local_model)
|
||||||
|
'''
|
||||||
|
running_models = api.list_running_models()
|
||||||
|
|
||||||
|
if not running_models:
|
||||||
|
return "", False
|
||||||
|
|
||||||
|
if LLM_MODEL in running_models:
|
||||||
|
return LLM_MODEL, True
|
||||||
|
|
||||||
|
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
|
||||||
|
if local_models:
|
||||||
|
return local_models[0], True
|
||||||
|
|
||||||
|
return running_models[0], False
|
||||||
|
|
||||||
|
|
||||||
def dialogue_page(api: ApiRequest):
|
def dialogue_page(api: ApiRequest):
|
||||||
chat_box.init_session()
|
chat_box.init_session()
|
||||||
|
|
||||||
@ -51,7 +68,6 @@ def dialogue_page(api: ApiRequest):
|
|||||||
if cur_kb:
|
if cur_kb:
|
||||||
text = f"{text} 当前知识库: `{cur_kb}`。"
|
text = f"{text} 当前知识库: `{cur_kb}`。"
|
||||||
st.toast(text)
|
st.toast(text)
|
||||||
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
|
|
||||||
|
|
||||||
dialogue_mode = st.selectbox("请选择对话模式:",
|
dialogue_mode = st.selectbox("请选择对话模式:",
|
||||||
["LLM 对话",
|
["LLM 对话",
|
||||||
@ -65,7 +81,7 @@ def dialogue_page(api: ApiRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def on_llm_change():
|
def on_llm_change():
|
||||||
config = get_model_worker_config(llm_model)
|
config = api.get_model_config(llm_model)
|
||||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||||
st.session_state["prev_llm_model"] = llm_model
|
st.session_state["prev_llm_model"] = llm_model
|
||||||
st.session_state["cur_llm_model"] = st.session_state.llm_model
|
st.session_state["cur_llm_model"] = st.session_state.llm_model
|
||||||
@ -75,15 +91,20 @@ def dialogue_page(api: ApiRequest):
|
|||||||
return f"{x} (Running)"
|
return f"{x} (Running)"
|
||||||
return x
|
return x
|
||||||
|
|
||||||
running_models = api.list_running_models()
|
running_models = list(api.list_running_models())
|
||||||
available_models = []
|
available_models = []
|
||||||
config_models = api.list_config_models()
|
config_models = api.list_config_models()
|
||||||
for models in config_models.values():
|
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
|
||||||
for m in models:
|
for m in worker_models:
|
||||||
if m not in running_models:
|
if m not in running_models and m != "default":
|
||||||
available_models.append(m)
|
available_models.append(m)
|
||||||
|
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT)
|
||||||
|
if not v.get("provider") and k not in running_models:
|
||||||
|
print(k, v)
|
||||||
|
available_models.append(k)
|
||||||
|
|
||||||
llm_models = running_models + available_models
|
llm_models = running_models + available_models
|
||||||
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
|
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
|
||||||
llm_model = st.selectbox("选择LLM模型:",
|
llm_model = st.selectbox("选择LLM模型:",
|
||||||
llm_models,
|
llm_models,
|
||||||
index,
|
index,
|
||||||
@ -92,7 +113,7 @@ def dialogue_page(api: ApiRequest):
|
|||||||
key="llm_model",
|
key="llm_model",
|
||||||
)
|
)
|
||||||
if (st.session_state.get("prev_llm_model") != llm_model
|
if (st.session_state.get("prev_llm_model") != llm_model
|
||||||
and not get_model_worker_config(llm_model).get("online_api")
|
and not api.get_model_config(llm_model).get("online_api")
|
||||||
and llm_model not in running_models):
|
and llm_model not in running_models):
|
||||||
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
||||||
prev_model = st.session_state.get("prev_llm_model")
|
prev_model = st.session_state.get("prev_llm_model")
|
||||||
@ -114,7 +135,7 @@ def dialogue_page(api: ApiRequest):
|
|||||||
|
|
||||||
if dialogue_mode == "知识库问答":
|
if dialogue_mode == "知识库问答":
|
||||||
with st.expander("知识库配置", True):
|
with st.expander("知识库配置", True):
|
||||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
kb_list = api.list_knowledge_bases()
|
||||||
selected_kb = st.selectbox(
|
selected_kb = st.selectbox(
|
||||||
"请选择知识库:",
|
"请选择知识库:",
|
||||||
kb_list,
|
kb_list,
|
||||||
@ -126,7 +147,7 @@ def dialogue_page(api: ApiRequest):
|
|||||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||||
elif dialogue_mode == "搜索引擎问答":
|
elif dialogue_mode == "搜索引擎问答":
|
||||||
search_engine_list = list(SEARCH_ENGINES.keys())
|
search_engine_list = api.list_search_engines()
|
||||||
with st.expander("搜索引擎配置", True):
|
with st.expander("搜索引擎配置", True):
|
||||||
search_engine = st.selectbox(
|
search_engine = st.selectbox(
|
||||||
label="请选择搜索引擎",
|
label="请选择搜索引擎",
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
# 该文件包含webui通用工具,可以被不同的webui使用
|
# 该文件封装了对api.py的请求,可以被不同的webui使用
|
||||||
|
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
|
||||||
|
|
||||||
|
|
||||||
from typing import *
|
from typing import *
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
|
||||||
from configs import (
|
from configs import (
|
||||||
EMBEDDING_MODEL,
|
EMBEDDING_MODEL,
|
||||||
DEFAULT_VS_TYPE,
|
DEFAULT_VS_TYPE,
|
||||||
KB_ROOT_PATH,
|
|
||||||
LLM_MODEL,
|
LLM_MODEL,
|
||||||
HISTORY_LEN,
|
|
||||||
TEMPERATURE,
|
TEMPERATURE,
|
||||||
SCORE_THRESHOLD,
|
SCORE_THRESHOLD,
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
@ -14,59 +16,44 @@ from configs import (
|
|||||||
ZH_TITLE_ENHANCE,
|
ZH_TITLE_ENHANCE,
|
||||||
VECTOR_SEARCH_TOP_K,
|
VECTOR_SEARCH_TOP_K,
|
||||||
SEARCH_ENGINE_TOP_K,
|
SEARCH_ENGINE_TOP_K,
|
||||||
FSCHAT_MODEL_WORKERS,
|
|
||||||
HTTPX_DEFAULT_TIMEOUT,
|
HTTPX_DEFAULT_TIMEOUT,
|
||||||
logger, log_verbose,
|
logger, log_verbose,
|
||||||
)
|
)
|
||||||
import httpx
|
import httpx
|
||||||
import asyncio
|
|
||||||
from server.chat.openai_chat import OpenAiChatMsgIn
|
from server.chat.openai_chat import OpenAiChatMsgIn
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client
|
from server.utils import run_async, set_httpx_config, api_address, get_httpx_client
|
||||||
|
|
||||||
from configs.model_config import NLTK_DATA_PATH
|
|
||||||
import nltk
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
KB_ROOT_PATH = Path(KB_ROOT_PATH)
|
|
||||||
set_httpx_config()
|
set_httpx_config()
|
||||||
|
|
||||||
|
|
||||||
class ApiRequest:
|
class ApiRequest:
|
||||||
'''
|
'''
|
||||||
api.py调用的封装,主要实现:
|
api.py调用的封装(同步模式),简化api调用方式
|
||||||
1. 简化api调用方式
|
|
||||||
2. 实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str = api_address(),
|
base_url: str = api_address(),
|
||||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||||
no_remote_api: bool = False, # call api view function directly
|
|
||||||
):
|
):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.no_remote_api = no_remote_api
|
self._use_async = False
|
||||||
self._client = get_httpx_client()
|
self._client = None
|
||||||
self._aclient = get_httpx_client(use_async=True)
|
|
||||||
if no_remote_api:
|
|
||||||
logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。")
|
|
||||||
|
|
||||||
def _parse_url(self, url: str) -> str:
|
@property
|
||||||
if (not url.startswith("http")
|
def client(self):
|
||||||
and self.base_url
|
if self._client is None or self._client.is_closed:
|
||||||
):
|
self._client = get_httpx_client(base_url=self.base_url,
|
||||||
part1 = self.base_url.strip(" /")
|
use_async=self._use_async,
|
||||||
part2 = url.strip(" /")
|
timeout=self.timeout)
|
||||||
return f"{part1}/{part2}"
|
return self._client
|
||||||
else:
|
|
||||||
return url
|
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
@ -75,44 +62,19 @@ class ApiRequest:
|
|||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[httpx.Response, None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
return self._client.stream("GET", url, params=params, **kwargs)
|
return self.client.stream("GET", url, params=params, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self._client.get(url, params=params, **kwargs)
|
return self.client.get(url, params=params, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"error when get {url}: {e}"
|
msg = f"error when get {url}: {e}"
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def aget(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
params: Union[Dict, List[Tuple], bytes] = None,
|
|
||||||
retry: int = 3,
|
|
||||||
stream: bool = False,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[httpx.Response, None]:
|
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
|
|
||||||
while retry > 0:
|
|
||||||
try:
|
|
||||||
if stream:
|
|
||||||
return await self._aclient.stream("GET", url, params=params, **kwargs)
|
|
||||||
else:
|
|
||||||
return await self._aclient.get(url, params=params, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
msg = f"error when aget {url}: {e}"
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
retry -= 1
|
|
||||||
|
|
||||||
def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
@ -121,45 +83,19 @@ class ApiRequest:
|
|||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Union[httpx.Response, None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
return self._client.stream("POST", url, data=data, json=json, **kwargs)
|
return self.client.stream("POST", url, data=data, json=json, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self._client.post(url, data=data, json=json, **kwargs)
|
return self.client.post(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"error when post {url}: {e}"
|
msg = f"error when post {url}: {e}"
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def apost(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
data: Dict = None,
|
|
||||||
json: Dict = None,
|
|
||||||
retry: int = 3,
|
|
||||||
stream: bool = False,
|
|
||||||
**kwargs: Any
|
|
||||||
) -> Union[httpx.Response, None]:
|
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
|
|
||||||
while retry > 0:
|
|
||||||
try:
|
|
||||||
if stream:
|
|
||||||
return await self._client.stream("POST", url, data=data, json=json, **kwargs)
|
|
||||||
else:
|
|
||||||
return await self._client.post(url, data=data, json=json, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
msg = f"error when apost {url}: {e}"
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
retry -= 1
|
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
@ -168,65 +104,19 @@ class ApiRequest:
|
|||||||
retry: int = 3,
|
retry: int = 3,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Union[httpx.Response, None]:
|
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
while retry > 0:
|
while retry > 0:
|
||||||
try:
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
return self._client.stream("DELETE", url, data=data, json=json, **kwargs)
|
return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self._client.delete(url, data=data, json=json, **kwargs)
|
return self.client.delete(url, data=data, json=json, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"error when delete {url}: {e}"
|
msg = f"error when delete {url}: {e}"
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
retry -= 1
|
retry -= 1
|
||||||
|
|
||||||
async def adelete(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
data: Dict = None,
|
|
||||||
json: Dict = None,
|
|
||||||
retry: int = 3,
|
|
||||||
stream: bool = False,
|
|
||||||
**kwargs: Any
|
|
||||||
) -> Union[httpx.Response, None]:
|
|
||||||
url = self._parse_url(url)
|
|
||||||
kwargs.setdefault("timeout", self.timeout)
|
|
||||||
|
|
||||||
while retry > 0:
|
|
||||||
try:
|
|
||||||
if stream:
|
|
||||||
return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs)
|
|
||||||
else:
|
|
||||||
return await self._aclient.delete(url, data=data, json=json, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
msg = f"error when adelete {url}: {e}"
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
retry -= 1
|
|
||||||
|
|
||||||
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
|
|
||||||
'''
|
|
||||||
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
except:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
|
|
||||||
try:
|
|
||||||
for chunk in iter_over_async(response.body_iterator, loop):
|
|
||||||
if as_json and chunk:
|
|
||||||
yield json.loads(chunk)
|
|
||||||
elif chunk.strip():
|
|
||||||
yield chunk
|
|
||||||
except Exception as e:
|
|
||||||
msg = f"error when run fastapi router: {e}"
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
|
|
||||||
def _httpx_stream2generator(
|
def _httpx_stream2generator(
|
||||||
self,
|
self,
|
||||||
response: contextlib._GeneratorContextManager,
|
response: contextlib._GeneratorContextManager,
|
||||||
@ -235,6 +125,39 @@ class ApiRequest:
|
|||||||
'''
|
'''
|
||||||
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
将httpx.stream返回的GeneratorContextManager转化为普通生成器
|
||||||
'''
|
'''
|
||||||
|
async def ret_async(response, as_json):
|
||||||
|
try:
|
||||||
|
async with response as r:
|
||||||
|
async for chunk in r.aiter_text(None):
|
||||||
|
if not chunk: # fastchat api yield empty bytes on start and end
|
||||||
|
continue
|
||||||
|
if as_json:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
pprint(data, depth=1)
|
||||||
|
yield data
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
|
||||||
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
|
exc_info=e if log_verbose else None)
|
||||||
|
else:
|
||||||
|
# print(chunk, end="", flush=True)
|
||||||
|
yield chunk
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
||||||
|
logger.error(msg)
|
||||||
|
yield {"code": 500, "msg": msg}
|
||||||
|
except httpx.ReadTimeout as e:
|
||||||
|
msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})"
|
||||||
|
logger.error(msg)
|
||||||
|
yield {"code": 500, "msg": msg}
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"API通信遇到错误:{e}"
|
||||||
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
|
exc_info=e if log_verbose else None)
|
||||||
|
yield {"code": 500, "msg": msg}
|
||||||
|
|
||||||
|
def ret_sync(response, as_json):
|
||||||
try:
|
try:
|
||||||
with response as r:
|
with response as r:
|
||||||
for chunk in r.iter_text(None):
|
for chunk in r.iter_text(None):
|
||||||
@ -255,10 +178,9 @@ class ApiRequest:
|
|||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
logger.error(msg)
|
|
||||||
yield {"code": 500, "msg": msg}
|
yield {"code": 500, "msg": msg}
|
||||||
except httpx.ReadTimeout as e:
|
except httpx.ReadTimeout as e:
|
||||||
msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')。({e})"
|
msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})"
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
yield {"code": 500, "msg": msg}
|
yield {"code": 500, "msg": msg}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -267,6 +189,54 @@ class ApiRequest:
|
|||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
yield {"code": 500, "msg": msg}
|
yield {"code": 500, "msg": msg}
|
||||||
|
|
||||||
|
if self._use_async:
|
||||||
|
return ret_async(response, as_json)
|
||||||
|
else:
|
||||||
|
return ret_sync(response, as_json)
|
||||||
|
|
||||||
|
def _get_response_value(
|
||||||
|
self,
|
||||||
|
response: httpx.Response,
|
||||||
|
as_json: bool = False,
|
||||||
|
value_func: Callable = None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
转换同步或异步请求返回的响应
|
||||||
|
`as_json`: 返回json
|
||||||
|
`value_func`: 用户可以自定义返回值,该函数接受response或json
|
||||||
|
'''
|
||||||
|
def to_json(r):
|
||||||
|
try:
|
||||||
|
return r.json()
|
||||||
|
except Exception as e:
|
||||||
|
msg = "API未能返回正确的JSON。" + str(e)
|
||||||
|
if log_verbose:
|
||||||
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
|
exc_info=e if log_verbose else None)
|
||||||
|
return {"code": 500, "msg": msg}
|
||||||
|
|
||||||
|
if value_func is None:
|
||||||
|
value_func = (lambda r: r)
|
||||||
|
|
||||||
|
async def ret_async(response):
|
||||||
|
if as_json:
|
||||||
|
return value_func(to_json(await response))
|
||||||
|
else:
|
||||||
|
return value_func(await response)
|
||||||
|
|
||||||
|
if self._use_async:
|
||||||
|
return ret_async(response)
|
||||||
|
else:
|
||||||
|
if as_json:
|
||||||
|
return value_func(to_json(response))
|
||||||
|
else:
|
||||||
|
return value_func(response)
|
||||||
|
|
||||||
|
# 服务器信息
|
||||||
|
def get_server_configs(self, **kwargs):
|
||||||
|
response = self.post("/server/configs", **kwargs)
|
||||||
|
return self._get_response_value(response, lambda r: r.json())
|
||||||
|
|
||||||
# 对话相关操作
|
# 对话相关操作
|
||||||
|
|
||||||
def chat_fastchat(
|
def chat_fastchat(
|
||||||
@ -275,15 +245,12 @@ class ApiRequest:
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
model: str = LLM_MODEL,
|
model: str = LLM_MODEL,
|
||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
max_tokens: int = 1024, # TODO:根据message内容自动计算max_tokens
|
max_tokens: int = 1024,
|
||||||
no_remote_api: bool = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/fastchat接口
|
对应api.py/chat/fastchat接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
msg = OpenAiChatMsgIn(**{
|
msg = OpenAiChatMsgIn(**{
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
@ -293,11 +260,6 @@ class ApiRequest:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
})
|
})
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.chat.openai_chat import openai_chat
|
|
||||||
response = run_async(openai_chat(msg))
|
|
||||||
return self._fastapi_stream2generator(response)
|
|
||||||
else:
|
|
||||||
data = msg.dict(exclude_unset=True, exclude_none=True)
|
data = msg.dict(exclude_unset=True, exclude_none=True)
|
||||||
print(f"received input message:")
|
print(f"received input message:")
|
||||||
pprint(data)
|
pprint(data)
|
||||||
@ -318,14 +280,11 @@ class ApiRequest:
|
|||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
prompt_name: str = "llm_chat",
|
prompt_name: str = "llm_chat",
|
||||||
no_remote_api: bool = None,
|
**kwargs,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/chat接口
|
对应api.py/chat/chat接口 #TODO: 考虑是否返回json
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
@ -339,12 +298,7 @@ class ApiRequest:
|
|||||||
print(f"received input message:")
|
print(f"received input message:")
|
||||||
pprint(data)
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
|
||||||
from server.chat.chat import chat
|
|
||||||
response = run_async(chat(**data))
|
|
||||||
return self._fastapi_stream2generator(response)
|
|
||||||
else:
|
|
||||||
response = self.post("/chat/chat", json=data, stream=True)
|
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response)
|
||||||
|
|
||||||
def agent_chat(
|
def agent_chat(
|
||||||
@ -355,14 +309,10 @@ class ApiRequest:
|
|||||||
model: str = LLM_MODEL,
|
model: str = LLM_MODEL,
|
||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/agent_chat 接口
|
对应api.py/chat/agent_chat 接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
@ -375,11 +325,6 @@ class ApiRequest:
|
|||||||
print(f"received input message:")
|
print(f"received input message:")
|
||||||
pprint(data)
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.chat.agent_chat import agent_chat
|
|
||||||
response = run_async(agent_chat(**data))
|
|
||||||
return self._fastapi_stream2generator(response)
|
|
||||||
else:
|
|
||||||
response = self.post("/chat/agent_chat", json=data, stream=True)
|
response = self.post("/chat/agent_chat", json=data, stream=True)
|
||||||
return self._httpx_stream2generator(response)
|
return self._httpx_stream2generator(response)
|
||||||
|
|
||||||
@ -395,14 +340,10 @@ class ApiRequest:
|
|||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
prompt_name: str = "knowledge_base_chat",
|
prompt_name: str = "knowledge_base_chat",
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/knowledge_base_chat接口
|
对应api.py/chat/knowledge_base_chat接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
@ -413,18 +354,12 @@ class ApiRequest:
|
|||||||
"model_name": model,
|
"model_name": model,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"local_doc_url": no_remote_api,
|
|
||||||
"prompt_name": prompt_name,
|
"prompt_name": prompt_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"received input message:")
|
print(f"received input message:")
|
||||||
pprint(data)
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
|
||||||
response = run_async(knowledge_base_chat(**data))
|
|
||||||
return self._fastapi_stream2generator(response, as_json=True)
|
|
||||||
else:
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/knowledge_base_chat",
|
"/chat/knowledge_base_chat",
|
||||||
json=data,
|
json=data,
|
||||||
@ -443,14 +378,10 @@ class ApiRequest:
|
|||||||
temperature: float = TEMPERATURE,
|
temperature: float = TEMPERATURE,
|
||||||
max_tokens: int = 1024,
|
max_tokens: int = 1024,
|
||||||
prompt_name: str = "knowledge_base_chat",
|
prompt_name: str = "knowledge_base_chat",
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/search_engine_chat接口
|
对应api.py/chat/search_engine_chat接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"search_engine_name": search_engine_name,
|
"search_engine_name": search_engine_name,
|
||||||
@ -466,11 +397,6 @@ class ApiRequest:
|
|||||||
print(f"received input message:")
|
print(f"received input message:")
|
||||||
pprint(data)
|
pprint(data)
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.chat.search_engine_chat import search_engine_chat
|
|
||||||
response = run_async(search_engine_chat(**data))
|
|
||||||
return self._fastapi_stream2generator(response, as_json=True)
|
|
||||||
else:
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/chat/search_engine_chat",
|
"/chat/search_engine_chat",
|
||||||
json=data,
|
json=data,
|
||||||
@ -480,116 +406,65 @@ class ApiRequest:
|
|||||||
|
|
||||||
# 知识库相关操作
|
# 知识库相关操作
|
||||||
|
|
||||||
def _check_httpx_json_response(
|
|
||||||
self,
|
|
||||||
response: httpx.Response,
|
|
||||||
errorMsg: str = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。",
|
|
||||||
) -> Dict:
|
|
||||||
'''
|
|
||||||
check whether httpx returns correct data with normal Response.
|
|
||||||
error in api with streaming support was checked in _httpx_stream2enerator
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
return response.json()
|
|
||||||
except Exception as e:
|
|
||||||
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
|
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
|
||||||
exc_info=e if log_verbose else None)
|
|
||||||
return {"code": 500, "msg": msg}
|
|
||||||
|
|
||||||
def list_knowledge_bases(
|
def list_knowledge_bases(
|
||||||
self,
|
self,
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/list_knowledge_bases接口
|
对应api.py/knowledge_base/list_knowledge_bases接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.knowledge_base.kb_api import list_kbs
|
|
||||||
response = list_kbs()
|
|
||||||
return response.data
|
|
||||||
else:
|
|
||||||
response = self.get("/knowledge_base/list_knowledge_bases")
|
response = self.get("/knowledge_base/list_knowledge_bases")
|
||||||
data = self._check_httpx_json_response(response)
|
return self._get_response_value(response,
|
||||||
return data.get("data", [])
|
as_json=True,
|
||||||
|
value_func=lambda r: r.get("data", []))
|
||||||
|
|
||||||
def create_knowledge_base(
|
def create_knowledge_base(
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
vector_store_type: str = "faiss",
|
vector_store_type: str = DEFAULT_VS_TYPE,
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/create_knowledge_base接口
|
对应api.py/knowledge_base/create_knowledge_base接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
"vector_store_type": vector_store_type,
|
"vector_store_type": vector_store_type,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.knowledge_base.kb_api import create_kb
|
|
||||||
response = create_kb(**data)
|
|
||||||
return response.dict()
|
|
||||||
else:
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/create_knowledge_base",
|
"/knowledge_base/create_knowledge_base",
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._get_response_value(response, as_json=True)
|
||||||
|
|
||||||
def delete_knowledge_base(
|
def delete_knowledge_base(
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/delete_knowledge_base接口
|
对应api.py/knowledge_base/delete_knowledge_base接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.knowledge_base.kb_api import delete_kb
|
|
||||||
response = delete_kb(knowledge_base_name)
|
|
||||||
return response.dict()
|
|
||||||
else:
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/delete_knowledge_base",
|
"/knowledge_base/delete_knowledge_base",
|
||||||
json=f"{knowledge_base_name}",
|
json=f"{knowledge_base_name}",
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._get_response_value(response, as_json=True)
|
||||||
|
|
||||||
def list_kb_docs(
|
def list_kb_docs(
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
no_remote_api: bool = None,
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/list_files接口
|
对应api.py/knowledge_base/list_files接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.knowledge_base.kb_doc_api import list_files
|
|
||||||
response = list_files(knowledge_base_name)
|
|
||||||
return response.data
|
|
||||||
else:
|
|
||||||
response = self.get(
|
response = self.get(
|
||||||
"/knowledge_base/list_files",
|
"/knowledge_base/list_files",
|
||||||
params={"knowledge_base_name": knowledge_base_name}
|
params={"knowledge_base_name": knowledge_base_name}
|
||||||
)
|
)
|
||||||
data = self._check_httpx_json_response(response)
|
return self._get_response_value(response,
|
||||||
return data.get("data", [])
|
as_json=True,
|
||||||
|
value_func=lambda r: r.get("data", []))
|
||||||
|
|
||||||
def search_kb_docs(
|
def search_kb_docs(
|
||||||
self,
|
self,
|
||||||
@ -597,14 +472,10 @@ class ApiRequest:
|
|||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
score_threshold: int = SCORE_THRESHOLD,
|
score_threshold: int = SCORE_THRESHOLD,
|
||||||
no_remote_api: bool = None,
|
|
||||||
) -> List:
|
) -> List:
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/search_docs接口
|
对应api.py/knowledge_base/search_docs接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
|
||||||
no_remote_api = self.no_remote_api
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
@ -612,16 +483,11 @@ class ApiRequest:
|
|||||||
"score_threshold": score_threshold,
|
"score_threshold": score_threshold,
|
||||||
}
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
|
||||||
from server.knowledge_base.kb_doc_api import search_docs
|
|
||||||
return search_docs(**data)
|
|
||||||
else:
|
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/search_docs",
|
"/knowledge_base/search_docs",
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
data = self._check_httpx_json_response(response)
|
return self._get_response_value(response, as_json=True)
|
||||||
return data
|
|
||||||
|
|
||||||
     def upload_kb_docs(
         self,
@@ -634,14 +500,10 @@ class ApiRequest:
         zh_title_enhance=ZH_TITLE_ENHANCE,
         docs: Dict = {},
         not_refresh_vs_cache: bool = False,
-        no_remote_api: bool = None,
     ):
         '''
         对应api.py/knowledge_base/upload_docs接口
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         def convert_file(file, filename=None):
             if isinstance(file, bytes): # raw bytes
                 file = BytesIO(file)
@@ -664,21 +526,6 @@ class ApiRequest:
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }
 
-        if no_remote_api:
-            from server.knowledge_base.kb_doc_api import upload_docs
-            from fastapi import UploadFile
-            from tempfile import SpooledTemporaryFile
-
-            upload_files = []
-            for filename, file in files:
-                temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
-                temp_file.write(file.read())
-                temp_file.seek(0)
-                upload_files.append(UploadFile(file=temp_file, filename=filename))
-
-            response = upload_docs(upload_files, **data)
-            return response.dict()
-        else:
         if isinstance(data["docs"], dict):
             data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
         response = self.post(
@@ -686,7 +533,7 @@ class ApiRequest:
             data=data,
             files=[("files", (filename, file)) for filename, file in files],
         )
-        return self._check_httpx_json_response(response)
+        return self._get_response_value(response, as_json=True)
 
     def delete_kb_docs(
         self,
@@ -694,14 +541,10 @@ class ApiRequest:
         file_names: List[str],
         delete_content: bool = False,
         not_refresh_vs_cache: bool = False,
-        no_remote_api: bool = None,
     ):
         '''
         对应api.py/knowledge_base/delete_docs接口
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         data = {
             "knowledge_base_name": knowledge_base_name,
             "file_names": file_names,
@@ -709,16 +552,11 @@ class ApiRequest:
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }
 
-        if no_remote_api:
-            from server.knowledge_base.kb_doc_api import delete_docs
-            response = delete_docs(**data)
-            return response.dict()
-        else:
         response = self.post(
             "/knowledge_base/delete_docs",
             json=data,
         )
-        return self._check_httpx_json_response(response)
+        return self._get_response_value(response, as_json=True)
 
     def update_kb_docs(
         self,
@@ -730,14 +568,10 @@ class ApiRequest:
         zh_title_enhance=ZH_TITLE_ENHANCE,
         docs: Dict = {},
         not_refresh_vs_cache: bool = False,
-        no_remote_api: bool = None,
     ):
         '''
         对应api.py/knowledge_base/update_docs接口
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         data = {
             "knowledge_base_name": knowledge_base_name,
             "file_names": file_names,
@@ -748,18 +582,15 @@ class ApiRequest:
             "docs": docs,
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }
-        if no_remote_api:
-            from server.knowledge_base.kb_doc_api import update_docs
-            response = update_docs(**data)
-            return response.dict()
-        else:
         if isinstance(data["docs"], dict):
             data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
 
         response = self.post(
             "/knowledge_base/update_docs",
             json=data,
         )
-        return self._check_httpx_json_response(response)
+        return self._get_response_value(response, as_json=True)
 
     def recreate_vector_store(
         self,
@@ -770,14 +601,10 @@ class ApiRequest:
         chunk_size=CHUNK_SIZE,
         chunk_overlap=OVERLAP_SIZE,
         zh_title_enhance=ZH_TITLE_ENHANCE,
-        no_remote_api: bool = None,
     ):
         '''
         对应api.py/knowledge_base/recreate_vector_store接口
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         data = {
             "knowledge_base_name": knowledge_base_name,
             "allow_empty_kb": allow_empty_kb,
@@ -788,11 +615,6 @@ class ApiRequest:
             "zh_title_enhance": zh_title_enhance,
         }
 
-        if no_remote_api:
-            from server.knowledge_base.kb_doc_api import recreate_vector_store
-            response = recreate_vector_store(**data)
-            return self._fastapi_stream2generator(response, as_json=True)
-        else:
         response = self.post(
             "/knowledge_base/recreate_vector_store",
             json=data,
@@ -805,88 +627,89 @@ class ApiRequest:
     def list_running_models(
         self,
         controller_address: str = None,
-        no_remote_api: bool = None,
     ):
         '''
         获取Fastchat中正运行的模型列表
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         data = {
             "controller_address": controller_address,
         }
-        if no_remote_api:
-            from server.llm_api import list_running_models
-            return list_running_models(**data).data
-        else:
-            r = self.post(
+        response = self.post(
             "/llm_model/list_running_models",
             json=data,
         )
-            return r.json().get("data", [])
+        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
 
-    def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
+    def list_config_models(self) -> Dict[str, List[str]]:
         '''
-        获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
-        如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
+        获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
-        if no_remote_api:
-            from server.llm_api import list_config_models
-            return list_config_models().data
-        else:
-            r = self.post(
+        response = self.post(
             "/llm_model/list_config_models",
         )
-            return r.json().get("data", {})
+        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
 
+    def get_model_config(
+        self,
+        model_name: str,
+    ) -> Dict:
+        '''
+        获取服务器上模型配置
+        '''
+        data={
+            "model_name": model_name,
+        }
+        response = self.post(
+            "/llm_model/get_model_config",
+        )
+        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+
+    def list_search_engines(self) -> List[str]:
+        '''
+        获取服务器支持的搜索引擎
+        '''
+        response = self.post(
+            "/server/list_search_engines",
+        )
+        return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+
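For reference, a minimal sketch of how the WEBUI side can consume the new server-config endpoints above (not part of this diff; the model name is illustrative and the return shapes are only indicative):

    api = ApiRequest()
    config_models = api.list_config_models()         # e.g. {"local": {...}, "online": {...}}
    model_cfg = api.get_model_config("chatglm2-6b")   # config served by /llm_model/get_model_config
    engines = api.list_search_engines()               # search engines enabled on the API server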
     def stop_llm_model(
         self,
         model_name: str,
         controller_address: str = None,
-        no_remote_api: bool = None,
     ):
         '''
         停止某个LLM模型。
         注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         data = {
             "model_name": model_name,
             "controller_address": controller_address,
         }
 
-        if no_remote_api:
-            from server.llm_api import stop_llm_model
-            return stop_llm_model(**data).dict()
-        else:
-            r = self.post(
+        response = self.post(
             "/llm_model/stop",
             json=data,
         )
-            return r.json()
+        return self._get_response_value(response, as_json=True)
 
     def change_llm_model(
         self,
         model_name: str,
         new_model_name: str,
         controller_address: str = None,
-        no_remote_api: bool = None,
     ):
         '''
         向fastchat controller请求切换LLM模型。
         '''
-        if no_remote_api is None:
-            no_remote_api = self.no_remote_api
-
         if not model_name or not new_model_name:
-            return
+            return {
+                "code": 500,
+                "msg": f"未指定模型名称"
+            }
+
+        def ret_sync():
             running_models = self.list_running_models()
             if new_model_name == model_name or new_model_name in running_models:
                 return {
@@ -901,7 +724,7 @@ class ApiRequest:
                 }
 
             config_models = self.list_config_models()
-            if new_model_name not in config_models.get("local", []):
+            if new_model_name not in config_models.get("local", {}):
                 return {
                     "code": 500,
                     "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
@@ -913,16 +736,55 @@ class ApiRequest:
                 "controller_address": controller_address,
             }
 
-            if no_remote_api:
-                from server.llm_api import change_llm_model
-                return change_llm_model(**data).dict()
-            else:
-                r = self.post(
+            response = self.post(
                 "/llm_model/change",
                 json=data,
-                timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
             )
-            return r.json()
+            return self._get_response_value(response, as_json=True)
 
+        async def ret_async():
+            running_models = await self.list_running_models()
+            if new_model_name == model_name or new_model_name in running_models:
+                return {
+                    "code": 200,
+                    "msg": "无需切换"
+                }
+
+            if model_name not in running_models:
+                return {
+                    "code": 500,
+                    "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+                }
+
+            config_models = await self.list_config_models()
+            if new_model_name not in config_models.get("local", {}):
+                return {
+                    "code": 500,
+                    "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+                }
+
+            data = {
+                "model_name": model_name,
+                "new_model_name": new_model_name,
+                "controller_address": controller_address,
+            }
+
+            response = self.post(
+                "/llm_model/change",
+                json=data,
+            )
+            return self._get_response_value(response, as_json=True)
+
+        if self._use_async:
+            return ret_async()
+        else:
+            return ret_sync()
+
+
+class AsyncApiRequest(ApiRequest):
+    def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
+        super().__init__(base_url, timeout)
+        self._use_async = True
+
 
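For reference, a minimal sketch of the sync/async split introduced here (not part of this diff; the model names are illustrative, and awaiting the async client assumes _get_response_value honors _use_async as in the code above):

    api = ApiRequest()             # synchronous client, returns plain values
    print(api.list_running_models())

    aapi = AsyncApiRequest()       # same interface with _use_async=True
    # inside an async function:
    #     result = await aapi.change_llm_model("chatglm2-6b", "qwen-7b-chat")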
 def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
@@ -950,7 +812,8 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
 
 
 if __name__ == "__main__":
-    api = ApiRequest(no_remote_api=True)
+    api = ApiRequest()
+    aapi = AsyncApiRequest()
 
     # print(api.chat_fastchat(
     #     messages=[{"role": "user", "content": "hello"}]