Optimize the server_config configuration items (#1293)

* update server_config.py:
- Add the HISTORY_LEN configuration item to model_config
- Move the helper functions from server_config into server.utils
- Unify the definition and call sites of set_httpx_timeout

* update webui.py:
Apply the configuration items from model_config: HISTORY_LEN, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K (see the sketch below).
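
A minimal sketch of the resulting call pattern (illustrative script, not part of the diff; the module paths are the ones this commit introduces):

# Illustrative: config values stay in configs.model_config, while the
# service helpers now live in server.utils instead of configs.server_config.
from configs.model_config import HISTORY_LEN, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K
from server.utils import api_address, set_httpx_timeout

set_httpx_timeout()  # single shared definition; default comes from server_config
print(api_address(), HISTORY_LEN, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K)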

---------

Co-authored-by: liunux4odoo <liunu@qq.com>
liunux4odoo 2023-08-29 10:06:09 +08:00 committed by GitHub
parent ca0ae29fef
commit 34a416b941
9 changed files with 87 additions and 70 deletions

View File

@@ -80,6 +80,9 @@ llm_model_dict = {
 # LLM name
 LLM_MODEL = "chatglm2-6b"
+# number of history dialogue turns
+HISTORY_LEN = 3
+
 # device the LLM runs on
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
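
For context, a hypothetical sketch of how a cap like HISTORY_LEN is typically consumed; the helper below is illustrative and not part of this commit:

# Hypothetical helper (not in the diff): keep only the most recent
# HISTORY_LEN dialogue turns before sending chat history to the LLM.
def truncate_history(history: list, history_len: int = HISTORY_LEN) -> list:
    return history[-history_len:] if history_len > 0 else []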

View File

@@ -34,8 +34,8 @@ FSCHAT_MODEL_WORKERS = {
     "port": 20002,
     "device": LLM_DEVICE,
     # todo: parameters required for multi-GPU loading
-    "gpus": None, # GPUs to use, given as a str such as "0,1"
-    "num_gpus": 1, # number of GPUs to use
+    # "gpus": None, # GPUs to use, given as a str such as "0,1"
+    # "num_gpus": 1, # number of GPUs to use
     # the following are uncommon parameters; configure only if needed
     # "max_gpu_memory": "20GiB", # maximum VRAM occupied per GPU
     # "load_8bit": False, # enable 8-bit quantization
@@ -66,35 +66,3 @@ FSCHAT_CONTROLLER = {
     "port": 20001,
     "dispatch_method": "shortest_queue",
 }
-
-# do not change anything below
-def fschat_controller_address() -> str:
-    host = FSCHAT_CONTROLLER["host"]
-    port = FSCHAT_CONTROLLER["port"]
-    return f"http://{host}:{port}"
-
-def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
-    if model := FSCHAT_MODEL_WORKERS.get(model_name):
-        host = model["host"]
-        port = model["port"]
-        return f"http://{host}:{port}"
-
-def fschat_openai_api_address() -> str:
-    host = FSCHAT_OPENAI_API["host"]
-    port = FSCHAT_OPENAI_API["port"]
-    return f"http://{host}:{port}"
-
-def api_address() -> str:
-    host = API_SERVER["host"]
-    port = API_SERVER["port"]
-    return f"http://{host}:{port}"
-
-def webui_address() -> str:
-    host = WEBUI_SERVER["host"]
-    port = WEBUI_SERVER["port"]
-    return f"http://{host}:{port}"

View File

@@ -5,7 +5,7 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline
+from server.utils import MakeFastAPIOffline, set_httpx_timeout
 host_ip = "0.0.0.0"
@@ -15,13 +15,6 @@ openai_api_port = 8888
 base_url = "http://127.0.0.1:{}"
-def set_httpx_timeout(timeout=60.0):
-    import httpx
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
 def create_controller_app(
     dispatch_method="shortest_queue",
 ):

View File

@@ -5,6 +5,7 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
+from configs.model_config import LLM_MODEL
 from typing import Any, Optional
@@ -186,3 +187,72 @@ def MakeFastAPIOffline(
         with_google_fonts=False,
         redoc_favicon_url=favicon,
     )
+
+# get service information from server_config
+def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+    '''
+    Load the configuration items for a model worker.
+    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+    '''
+    from configs.server_config import FSCHAT_MODEL_WORKERS
+    from configs.model_config import llm_model_dict
+
+    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
+    config.update(llm_model_dict.get(model_name, {}))
+    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    return config
+
+def fschat_controller_address() -> str:
+    from configs.server_config import FSCHAT_CONTROLLER
+
+    host = FSCHAT_CONTROLLER["host"]
+    port = FSCHAT_CONTROLLER["port"]
+    return f"http://{host}:{port}"
+
+def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
+    if model := get_model_worker_config(model_name):
+        host = model["host"]
+        port = model["port"]
+        return f"http://{host}:{port}"
+    return ""
+
+def fschat_openai_api_address() -> str:
+    from configs.server_config import FSCHAT_OPENAI_API
+
+    host = FSCHAT_OPENAI_API["host"]
+    port = FSCHAT_OPENAI_API["port"]
+    return f"http://{host}:{port}"
+
+def api_address() -> str:
+    from configs.server_config import API_SERVER
+
+    host = API_SERVER["host"]
+    port = API_SERVER["port"]
+    return f"http://{host}:{port}"
+
+def webui_address() -> str:
+    from configs.server_config import WEBUI_SERVER
+
+    host = WEBUI_SERVER["host"]
+    port = WEBUI_SERVER["port"]
+    return f"http://{host}:{port}"
+
+def set_httpx_timeout(timeout: float = None):
+    '''
+    Set the default httpx timeout.
+    httpx's default timeout is 5 seconds, which is not enough when requesting an LLM response.
+    '''
+    import httpx
+    from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+
+    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
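
For clarity, a small worked example of the merge priority in get_model_worker_config; the dict values are made up for illustration:

# Hypothetical values: later update() calls win, so the per-model worker entry
# overrides llm_model_dict, which in turn overrides the "default" worker entry.
default_cfg = {"host": "127.0.0.1", "port": 20002, "device": "cuda"}  # FSCHAT_MODEL_WORKERS["default"]
model_cfg = {"device": "cpu", "api_key": "EMPTY"}                     # llm_model_dict[model_name]
worker_cfg = {"port": 20003}                                          # FSCHAT_MODEL_WORKERS[model_name]

config = default_cfg.copy()
config.update(model_cfg)
config.update(worker_cfg)
assert config == {"host": "127.0.0.1", "port": 20003, "device": "cpu", "api_key": "EMPTY"}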

View File

@@ -17,21 +17,15 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
     logger
-from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
-                                   FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
-                                   fschat_openai_api_address, )
+from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
+                                   FSCHAT_OPENAI_API, )
+from server.utils import (fschat_controller_address, fschat_model_worker_address,
+                          fschat_openai_api_address, set_httpx_timeout)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
 from configs import VERSION
-
-def set_httpx_timeout(timeout=60.0):
-    import httpx
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
 def create_controller_app(
     dispatch_method: str,
 ) -> FastAPI:
@@ -328,7 +322,7 @@ def dump_server_info(after_start=False):
     import platform
     import langchain
     import fastchat
-    from configs.server_config import api_address, webui_address
+    from server.utils import api_address, webui_address

     print("\n")
     print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)

View File

@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
-from configs.server_config import api_address
+from server.utils import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path

View File

@@ -5,7 +5,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent.parent))
 from configs.model_config import BING_SUBSCRIPTION_KEY
-from configs.server_config import API_SERVER, api_address
+from server.utils import api_address
 from pprint import pprint

View File

@@ -59,9 +59,7 @@ def dialogue_page(api: ApiRequest):
         on_change=on_mode_change,
         key="dialogue_mode",
     )
-    history_len = st.number_input("历史对话轮数:", 0, 10, 3)
-    # todo: support history len
+    history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

     def on_kb_change():
         st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -75,7 +73,7 @@ def dialogue_page(api: ApiRequest):
         on_change=on_kb_change,
         key="selected_kb",
     )
-    kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
+    kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
     score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
     # chunk_content = st.checkbox("关联上下文", False, disabled=True)
     # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -87,7 +85,7 @@ def dialogue_page(api: ApiRequest):
         options=search_engine_list,
         index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
     )
-    se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
+    se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)

     # Display chat messages from history on app rerun

View File

@@ -6,6 +6,7 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    HISTORY_LEN,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
@@ -20,7 +21,7 @@ import json
 from io import BytesIO
 from server.db.repository.knowledge_base_repository import get_kb_detail
 from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async
+from server.utils import run_async, iter_over_async, set_httpx_timeout
 from configs.model_config import NLTK_DATA_PATH
 import nltk
@@ -28,16 +29,6 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from pprint import pprint

-def set_httpx_timeout(timeout=60.0):
-    '''
-    Set the default httpx timeout to 60 seconds.
-    httpx's default timeout is 5 seconds, which is not enough when requesting an LLM response.
-    '''
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

 KB_ROOT_PATH = Path(KB_ROOT_PATH)
 set_httpx_timeout()
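
With the local duplicate removed, the module-level set_httpx_timeout() call now resolves to the shared helper, whose default comes from HTTPX_DEFAULT_TIMEOUT in server_config. A sketch of the effect (the 300.0 value is arbitrary, for illustration only):

# Illustrative: set_httpx_timeout patches httpx's module-level defaults, so any
# later request made without an explicit timeout inherits the new value.
import httpx
from server.utils import set_httpx_timeout

set_httpx_timeout(300.0)
assert httpx._config.DEFAULT_TIMEOUT_CONFIG.read == 300.0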