Optimize server_config configuration options (#1293)

* update server_config.py:
- Add a HISTORY_LEN option to model_config
- Move the helper functions from server_config into server.utils
- Unify the definition and call sites of set_httpx_timeout

* update webui.py:
Apply the model_config options HISTORY_LEN, VECTOR_SEARCH_TOP_K, and SEARCH_ENGINE_TOP_K
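
In practice, unifying set_httpx_timeout means every caller now imports the single helper from server.utils. A minimal usage sketch (the explicit-timeout call below is illustrative, based on the new signature, not a line from this commit):

from server.utils import set_httpx_timeout

set_httpx_timeout()       # fall back to HTTPX_DEFAULT_TIMEOUT from server_config
set_httpx_timeout(300.0)  # or pass an explicit timeout in seconds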

---------

Co-authored-by: liunux4odoo <liunu@qq.com>
Authored by liunux4odoo on 2023-08-29 10:06:09 +08:00; committed by GitHub
parent ca0ae29fef
commit 34a416b941
9 changed files with 87 additions and 70 deletions

View File

@@ -80,6 +80,9 @@ llm_model_dict = {
# LLM name
LLM_MODEL = "chatglm2-6b"
# number of history dialogue turns
HISTORY_LEN = 3
# device to run the LLM on
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

View File

@@ -34,8 +34,8 @@ FSCHAT_MODEL_WORKERS = {
"port": 20002,
"device": LLM_DEVICE,
# todo: parameters needed for multi-GPU loading
"gpus": None, # GPUs to use, specified as a str like "0,1"
"num_gpus": 1, # number of GPUs to use
# "gpus": None, # GPUs to use, specified as a str like "0,1"
# "num_gpus": 1, # number of GPUs to use
# The following are less commonly used parameters; configure them as needed
# "max_gpu_memory": "20GiB", # maximum GPU memory used per GPU
# "load_8bit": False, # enable 8-bit quantization
@@ -66,35 +66,3 @@ FSCHAT_CONTROLLER = {
"port": 20001,
"dispatch_method": "shortest_queue",
}
# Do not change anything below
def fschat_controller_address() -> str:
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
if model := FSCHAT_MODEL_WORKERS.get(model_name):
host = model["host"]
port = model["port"]
return f"http://{host}:{port}"
def fschat_openai_api_address() -> str:
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}"
def api_address() -> str:
host = API_SERVER["host"]
port = API_SERVER["port"]
return f"http://{host}:{port}"
def webui_address() -> str:
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
return f"http://{host}:{port}"

View File

@@ -5,7 +5,7 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
from server.utils import MakeFastAPIOffline, set_httpx_timeout
host_ip = "0.0.0.0"
@@ -15,13 +15,6 @@ openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def create_controller_app(
dispatch_method="shortest_queue",
):

View File

@@ -5,6 +5,7 @@ import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL
from typing import Any, Optional
@@ -186,3 +187,72 @@ def MakeFastAPIOffline(
with_google_fonts=False,
redoc_favicon_url=favicon,
)
# Get service info from server_config
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
'''
Load the configuration for a model worker.
Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import llm_model_dict
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
return config
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
if model := get_model_worker_config(model_name):
host = model["host"]
port = model["port"]
return f"http://{host}:{port}"
return ""
def fschat_openai_api_address() -> str:
from configs.server_config import FSCHAT_OPENAI_API
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}"
def api_address() -> str:
from configs.server_config import API_SERVER
host = API_SERVER["host"]
port = API_SERVER["port"]
return f"http://{host}:{port}"
def webui_address() -> str:
from configs.server_config import WEBUI_SERVER
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
return f"http://{host}:{port}"
def set_httpx_timeout(timeout: float = None):
'''
Set the default httpx timeout.
The httpx default timeout is 5 seconds, which is not enough when requesting answers from an LLM.
'''
import httpx
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
timeout = timeout or HTTPX_DEFAULT_TIMEOUT
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
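
The dict-update order in get_model_worker_config implements the documented priority: later updates win, so per-model worker settings override llm_model_dict, which overrides the defaults. A hypothetical illustration (values invented for clarity):

default = {"device": "cuda", "port": 20002}
llm_model = {"port": 21000, "api_key": "EMPTY"}
worker = {"device": "cpu"}

config = default.copy()
config.update(llm_model)  # llm_model_dict[model_name] overrides the defaults
config.update(worker)     # FSCHAT_MODEL_WORKERS[model_name] overrides both
assert config == {"device": "cpu", "port": 21000, "api_key": "EMPTY"}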

View File

@@ -17,21 +17,15 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, )
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout)
from server.utils import MakeFastAPIOffline, FastAPI
import argparse
from typing import Tuple, List
from configs import VERSION
def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def create_controller_app(
dispatch_method: str,
) -> FastAPI:
@@ -328,7 +322,7 @@ def dump_server_info(after_start=False):
import platform
import langchain
import fastchat
from configs.server_config import api_address, webui_address
from server.utils import api_address, webui_address
print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)

View File

@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path

View File

@@ -5,7 +5,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs.server_config import API_SERVER, api_address
from server.utils import api_address
from pprint import pprint

View File

@@ -59,9 +59,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
history_len = st.number_input("历史对话轮数:", 0, 10, 3)
# todo: support history len
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -75,7 +73,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -87,7 +85,7 @@ def dialogue_page(api: ApiRequest):
options=search_engine_list,
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
# Display chat messages from history on app rerun
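
In st.number_input(label, min_value, max_value, value), the fourth positional argument is the widget default, so the configured values now seed the UI instead of hard-coded literals. The pattern in isolation (a standalone sketch, not code from this commit):

import streamlit as st
from configs.model_config import HISTORY_LEN, VECTOR_SEARCH_TOP_K

# The config value is only the default; users can still adjust the widget.
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)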

View File

@@ -6,6 +6,7 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
@@ -20,7 +21,7 @@ import json
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
from server.utils import run_async, iter_over_async, set_httpx_timeout
from configs.model_config import NLTK_DATA_PATH
import nltk
@@ -28,16 +29,6 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from pprint import pprint
def set_httpx_timeout(timeout=60.0):
'''
Set the default httpx timeout to 60 seconds.
The httpx default timeout is 5 seconds, which is not enough when requesting answers from an LLM.
'''
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
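
Why mutating httpx._config.DEFAULT_TIMEOUT_CONFIG works: httpx uses that object as the default timeout for clients created without an explicit timeout, so patching its connect/read/write attributes before any request raises the limits process-wide, which is why this module calls set_httpx_timeout() once at import time. A hedged sketch (this leans on httpx internals, which may change between versions):

import httpx
from server.utils import set_httpx_timeout

set_httpx_timeout(300.0)  # or set_httpx_timeout() to use HTTPX_DEFAULT_TIMEOUT
client = httpx.Client()   # created afterwards, so it picks up the patched values
print(client.timeout)     # connect/read/write now report 300.0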