From 34a416b9411264142568f7d889b44f0dc43c90d1 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Tue, 29 Aug 2023 10:06:09 +0800
Subject: [PATCH] Optimize the server_config options (#1293)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update server_config.py:

- add a HISTORY_LEN option to model_config
- move the helper functions from server_config into server.utils
- unify the definition and invocation of set_httpx_timeout

* update webui.py: apply the model_config options HISTORY_LEN,
  VECTOR_SEARCH_TOP_K and SEARCH_ENGINE_TOP_K

---------

Co-authored-by: liunux4odoo
---
 configs/model_config.py.example   |  3 ++
 configs/server_config.py.example  | 36 +---------------
 server/llm_api.py                 |  9 +---
 server/utils.py                   | 70 +++++++++++++++++++++++++++++++
 startup.py                        | 14 ++-----
 tests/api/test_kb_api.py          |  2 +-
 tests/api/test_stream_chat_api.py |  2 +-
 webui_pages/dialogue/dialogue.py  |  8 ++--
 webui_pages/utils.py              | 13 +----
 9 files changed, 87 insertions(+), 70 deletions(-)

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 308fb8ee..f46dad68 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -80,6 +80,9 @@ llm_model_dict = {
 # LLM name
 LLM_MODEL = "chatglm2-6b"
 
+# number of history dialogue turns to keep
+HISTORY_LEN = 3
+
 # device to run the LLM on
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index b0f37bf4..00f94ea5 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -34,8 +34,8 @@ FSCHAT_MODEL_WORKERS = {
     "port": 20002,
     "device": LLM_DEVICE,
     # todo: parameters that must be set for multi-GPU loading
-    "gpus": None,  # GPUs to use, specified as a str, e.g. "0,1"
-    "num_gpus": 1,  # number of GPUs to use
+    # "gpus": None,  # GPUs to use, specified as a str, e.g. "0,1"
+    # "num_gpus": 1,  # number of GPUs to use
     # the following are uncommon parameters; set them only if needed
     # "max_gpu_memory": "20GiB",  # maximum VRAM to use per GPU
     # "load_8bit": False,  # enable 8-bit quantization
@@ -66,35 +66,3 @@ FSCHAT_CONTROLLER = {
     "port": 20001,
     "dispatch_method": "shortest_queue",
 }
-
-
-# do not change anything below
-def fschat_controller_address() -> str:
-    host = FSCHAT_CONTROLLER["host"]
-    port = FSCHAT_CONTROLLER["port"]
-    return f"http://{host}:{port}"
-
-
-def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
-    if model := FSCHAT_MODEL_WORKERS.get(model_name):
-        host = model["host"]
-        port = model["port"]
-        return f"http://{host}:{port}"
-
-
-def fschat_openai_api_address() -> str:
-    host = FSCHAT_OPENAI_API["host"]
-    port = FSCHAT_OPENAI_API["port"]
-    return f"http://{host}:{port}"
-
-
-def api_address() -> str:
-    host = API_SERVER["host"]
-    port = API_SERVER["port"]
-    return f"http://{host}:{port}"
-
-
-def webui_address() -> str:
-    host = WEBUI_SERVER["host"]
-    port = WEBUI_SERVER["port"]
-    return f"http://{host}:{port}"
diff --git a/server/llm_api.py b/server/llm_api.py
index ab71b3db..7ef5891c 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -5,7 +5,7 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline
+from server.utils import MakeFastAPIOffline, set_httpx_timeout
 
 host_ip = "0.0.0.0"
 
@@ -15,13 +15,6 @@ openai_api_port = 8888
 base_url = "http://127.0.0.1:{}"
 
 
-def set_httpx_timeout(timeout=60.0):
-    import httpx
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
 def create_controller_app(
     dispatch_method="shortest_queue",
 ):
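
The server_config hunk above turns "gpus" and "num_gpus" into commented-out placeholders. For reference, a worker entry with those keys re-enabled might look like the sketch below; the model name and values are illustrative only, not part of this patch:

    # hypothetical FSCHAT_MODEL_WORKERS entry with the multi-GPU keys restored
    FSCHAT_MODEL_WORKERS = {
        "chatglm2-6b": {
            "host": "0.0.0.0",
            "port": 20002,
            "device": "cuda",
            "gpus": "0,1",   # GPUs to use, as a comma-separated str
            "num_gpus": 2,   # number of GPUs to use
            # "max_gpu_memory": "20GiB",  # cap on VRAM per GPU
        },
    }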
diff --git a/server/utils.py b/server/utils.py
index 4a887225..167b672f 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,6 +5,7 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
+from configs.model_config import LLM_MODEL
 
 from typing import Any, Optional
 
@@ -186,3 +187,72 @@ def MakeFastAPIOffline(
         with_google_fonts=False,
         redoc_favicon_url=favicon,
     )
+
+
+# read service info from server_config
+def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+    '''
+    Load the config for a model worker.
+    Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+    '''
+    from configs.server_config import FSCHAT_MODEL_WORKERS
+    from configs.model_config import llm_model_dict
+
+    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
+    config.update(llm_model_dict.get(model_name, {}))
+    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    return config
+
+
+def fschat_controller_address() -> str:
+    from configs.server_config import FSCHAT_CONTROLLER
+
+    host = FSCHAT_CONTROLLER["host"]
+    port = FSCHAT_CONTROLLER["port"]
+    return f"http://{host}:{port}"
+
+
+def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
+    if model := get_model_worker_config(model_name):
+        host = model["host"]
+        port = model["port"]
+        return f"http://{host}:{port}"
+    return ""
+
+
+def fschat_openai_api_address() -> str:
+    from configs.server_config import FSCHAT_OPENAI_API
+
+    host = FSCHAT_OPENAI_API["host"]
+    port = FSCHAT_OPENAI_API["port"]
+    return f"http://{host}:{port}"
+
+
+def api_address() -> str:
+    from configs.server_config import API_SERVER
+
+    host = API_SERVER["host"]
+    port = API_SERVER["port"]
+    return f"http://{host}:{port}"
+
+
+def webui_address() -> str:
+    from configs.server_config import WEBUI_SERVER
+
+    host = WEBUI_SERVER["host"]
+    port = WEBUI_SERVER["port"]
+    return f"http://{host}:{port}"
+
+
+def set_httpx_timeout(timeout: float = None):
+    '''
+    Set the default httpx timeout.
+    The httpx default timeout is 5 seconds, which is not enough when waiting for an LLM answer.
+    '''
+    import httpx
+    from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+
+    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
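
get_model_worker_config builds its result from plain dict.update() calls, so later sources override earlier ones. A minimal sketch of the precedence with made-up values:

    # precedence: FSCHAT_MODEL_WORKERS[name] > llm_model_dict[name] > defaults
    default = {"host": "0.0.0.0", "port": 20002, "device": "cuda"}
    llm_entry = {"local_model_path": "/models/chatglm2-6b", "device": "cpu"}
    worker_entry = {"port": 21002}

    config = default.copy()
    config.update(llm_entry)     # "device" now comes from llm_model_dict
    config.update(worker_entry)  # "port" now comes from FSCHAT_MODEL_WORKERS
    assert config["device"] == "cpu" and config["port"] == 21002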
diff --git a/startup.py b/startup.py
index 46b886ea..64a3bcca 100644
--- a/startup.py
+++ b/startup.py
@@ -17,21 +17,15 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
-                                   FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
-                                   fschat_openai_api_address, )
+                                   FSCHAT_OPENAI_API, )
+from server.utils import (fschat_controller_address, fschat_model_worker_address,
+                          fschat_openai_api_address, set_httpx_timeout)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
 from configs import VERSION
 
 
-def set_httpx_timeout(timeout=60.0):
-    import httpx
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
 def create_controller_app(
     dispatch_method: str,
 ) -> FastAPI:
@@ -328,7 +322,7 @@ def dump_server_info(after_start=False):
     import platform
     import langchain
     import fastchat
-    from configs.server_config import api_address, webui_address
+    from server.utils import api_address, webui_address
 
     print("\n")
     print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index 56142fae..cefeec0e 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -6,7 +6,7 @@ from pathlib import Path
 root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 
-from configs.server_config import api_address
+from server.utils import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path
 
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index ad9d3d89..4c2d5faf 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -5,7 +5,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent.parent))
 
 from configs.model_config import BING_SUBSCRIPTION_KEY
-from configs.server_config import API_SERVER, api_address
+from server.utils import api_address
 
 from pprint import pprint
 
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 04ece7da..2d5e260b 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -59,9 +59,7 @@ def dialogue_page(api: ApiRequest):
         on_change=on_mode_change,
         key="dialogue_mode",
     )
-    history_len = st.number_input("历史对话轮数:", 0, 10, 3)
-
-    # todo: support history len
+    history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
 
     def on_kb_change():
         st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -75,7 +73,7 @@ def dialogue_page(api: ApiRequest):
         on_change=on_kb_change,
         key="selected_kb",
     )
-    kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
+    kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
     score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
     # chunk_content = st.checkbox("关联上下文", False, disabled=True)
     # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -87,7 +85,7 @@ def dialogue_page(api: ApiRequest):
         options=search_engine_list,
         index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
     )
-    se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
+    se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
 
 
     # Display chat messages from history on app rerun
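
With the hard-coded default of 3 replaced by HISTORY_LEN, the dialogue page can also trim the history it sends with each request from the same value. A minimal sketch, where trim_history is a hypothetical helper and history is assumed to alternate user and assistant messages:

    from typing import Dict, List

    def trim_history(history: List[Dict], history_len: int) -> List[Dict]:
        # one dialogue turn = one user message plus one assistant reply
        if history_len <= 0:
            return []
        return history[-2 * history_len:]

    history = [{"role": "user", "content": "hi"},
               {"role": "assistant", "content": "hello"}] * 5
    assert len(trim_history(history, 3)) == 6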
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 827cb305..58b08e87 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -6,6 +6,7 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    HISTORY_LEN,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
@@ -20,7 +21,7 @@ import json
 from io import BytesIO
 from server.db.repository.knowledge_base_repository import get_kb_detail
 from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async
+from server.utils import run_async, iter_over_async, set_httpx_timeout
 from configs.model_config import NLTK_DATA_PATH
 import nltk
 
@@ -28,16 +29,6 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from pprint import pprint
 
 
-def set_httpx_timeout(timeout=60.0):
-    '''
-    Set the default httpx timeout to 60 seconds.
-    The httpx default timeout is 5 seconds, which is not enough when waiting for an LLM answer.
-    '''
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
-    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
 KB_ROOT_PATH = Path(KB_ROOT_PATH)
 set_httpx_timeout()
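
Since webui_pages/utils.py now calls the shared set_httpx_timeout() at import time, every httpx client created afterwards inherits the longer default. A usage sketch; the 300.0 override is illustrative, and like the patch itself it relies on httpx's internal DEFAULT_TIMEOUT_CONFIG:

    import httpx
    from server.utils import set_httpx_timeout

    set_httpx_timeout(300.0)  # patch httpx's module-level default timeout
    assert httpx._config.DEFAULT_TIMEOUT_CONFIG.read == 300.0

    # with no argument, the helper falls back to HTTPX_DEFAULT_TIMEOUT
    # from configs.server_config, per the diff above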