2024-06-01 12:16:56 +08:00

371 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import multiprocessing
from contextlib import asynccontextmanager
import multiprocessing as mp
import os
import logging
import sys
from multiprocessing import Process
logger = logging.getLogger()

# Set numexpr's maximum thread count to the detected number of CPU cores.
# This is best effort: if numexpr cannot be imported or queried, the
# environment variable is simply left untouched.
try:
    import numexpr

    n_cores = numexpr.utils.detect_number_of_cores()
    os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except Exception:
    # Deliberately ignored. Catching Exception (rather than using a bare
    # ``except``) still lets KeyboardInterrupt and SystemExit propagate.
    pass
from fastapi import FastAPI
import argparse
from typing import List, Dict
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
    """Replace ``app``'s lifespan context with one that signals startup.

    The installed lifespan handler calls ``started_event.set()`` before
    yielding, when an event was supplied; with ``started_event=None`` it
    yields without doing anything else.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        event = started_event
        if event is not None:
            # Tell whoever is waiting on the event that startup was reached.
            event.set()
        yield

    router = app.router
    router.lifespan_context = lifespan
def run_init_server(
        model_platforms_shard: Dict,
        started_event: mp.Event = None,
        run_mode: str = None,
        model_providers_cfg_path: str = None,
        provider_host: str = None,
        provider_port: int = None):
    """Start the model-providers server via ``chatchat.init_server.init_server``.

    ``model_providers_cfg_path``, ``provider_host`` and ``provider_port`` fall
    back to the matching ``chatchat.configs`` constants when left as ``None``.
    ``model_platforms_shard`` and ``started_event`` are forwarded unchanged.
    ``run_mode`` is accepted but not used by this function.
    """
    from chatchat.init_server import init_server
    from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG,
                                  MODEL_PROVIDERS_CFG_HOST,
                                  MODEL_PROVIDERS_CFG_PORT)

    # Only ``None`` triggers the fallback; other falsy values are kept as given.
    cfg_path = (MODEL_PROVIDERS_CFG_PATH_CONFIG
                if model_providers_cfg_path is None
                else model_providers_cfg_path)
    host = MODEL_PROVIDERS_CFG_HOST if provider_host is None else provider_host
    port = MODEL_PROVIDERS_CFG_PORT if provider_port is None else provider_port

    init_server(
        model_platforms_shard=model_platforms_shard,
        started_event=started_event,
        model_providers_cfg_path=cfg_path,
        provider_host=host,
        provider_port=port,
    )
def run_api_server(model_platforms_shard: Dict,
                   started_event: mp.Event = None,
                   run_mode: str = None):
    """Build the API application and serve it with uvicorn (blocking).

    Args:
        model_platforms_shard: Mapping whose ``'provider_platforms'`` entry,
            when present and non-empty, is appended to
            ``chatchat.configs.MODEL_PLATFORMS``.
        started_event: Optional event that the app's lifespan handler sets
            on startup (installed through ``_set_app_event``).
        run_mode: Passed through to ``create_app``.
    """
    from chatchat.server.api_server.server_app import create_app
    import uvicorn
    from chatchat.server.utils import set_httpx_config
    from chatchat.configs import MODEL_PLATFORMS, API_SERVER

    # Use the same guarded lookup as run_webui: a missing or empty entry is
    # skipped instead of raising KeyError.
    provider_platforms = model_platforms_shard.get('provider_platforms')
    if provider_platforms:
        MODEL_PLATFORMS.extend(provider_platforms)
    logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")

    set_httpx_config()
    app = create_app(run_mode=run_mode)
    _set_app_event(app, started_event)

    host = API_SERVER["host"]
    port = API_SERVER["port"]
    uvicorn.run(app, host=host, port=port)
def run_webui(model_platforms_shard: Dict,
              started_event: mp.Event = None, run_mode: str = None):
    """Run the Streamlit web UI (``webui.py`` next to this file); blocking.

    Args:
        model_platforms_shard: Mapping whose ``'provider_platforms'`` entry,
            when present and non-empty, is appended to
            ``chatchat.configs.MODEL_PLATFORMS``.
        started_event: Optional event set right before the Streamlit server
            takes over this process.
        run_mode: When equal to ``"lite"``, ``-- lite`` is passed to the
            script as extra arguments.
    """
    from chatchat.server.utils import set_httpx_config
    from chatchat.configs import MODEL_PLATFORMS, WEBUI_SERVER

    provider_platforms = model_platforms_shard.get('provider_platforms')
    if provider_platforms:
        MODEL_PLATFORMS.extend(provider_platforms)
    logger.info(f"Webui MODEL_PLATFORMS: {MODEL_PLATFORMS}")
    set_httpx_config()

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'webui.py')

    # Options handed to Streamlit's bootstrap. Only address, port and a few
    # theme values are set explicitly.
    # NOTE(review): the None entries presumably mean "use Streamlit's default"
    # -- confirm against the installed Streamlit version.
    flag_options = {
        'server_address': host,
        'server_port': port,
        'theme_base': 'light',
        'theme_primaryColor': '#165dff',
        'theme_secondaryBackgroundColor': '#f5f5f5',
        'theme_textColor': '#000000',
        'global_disableWatchdogWarning': None,
        'global_disableWidgetStateDuplicationWarning': None,
        'global_showWarningOnDirectExecution': None,
        'global_developmentMode': None, 'global_logLevel': None, 'global_unitTest': None,
        'global_suppressDeprecationWarnings': None, 'global_minCachedMessageSize': None,
        'global_maxCachedMessageAge': None, 'global_storeCachedForwardMessagesInMemory': None,
        'global_dataFrameSerialization': None, 'logger_level': None, 'logger_messageFormat': None,
        'logger_enableRich': None, 'client_caching': None, 'client_displayEnabled': None,
        'client_showErrorDetails': None, 'client_toolbarMode': None, 'client_showSidebarNavigation': None,
        'runner_magicEnabled': None, 'runner_installTracer': None, 'runner_fixMatplotlib': None,
        'runner_postScriptGC': None, 'runner_fastReruns': None,
        'runner_enforceSerializableSessionState': None, 'runner_enumCoercion': None,
        'server_folderWatchBlacklist': None, 'server_fileWatcherType': None, 'server_headless': None,
        'server_runOnSave': None, 'server_allowRunOnSave': None, 'server_scriptHealthCheckEnabled': None,
        'server_baseUrlPath': None, 'server_enableCORS': None, 'server_enableXsrfProtection': None,
        'server_maxUploadSize': None, 'server_maxMessageSize': None, 'server_enableArrowTruncation': None,
        'server_enableWebsocketCompression': None, 'server_enableStaticServing': None,
        'browser_serverAddress': None, 'browser_gatherUsageStats': None, 'browser_serverPort': None,
        'server_sslCertFile': None, 'server_sslKeyFile': None, 'ui_hideTopBar': None,
        'ui_hideSidebarNav': None, 'magic_displayRootDocString': None,
        'magic_displayLastExprIfNoSemicolon': None, 'deprecation_showfileUploaderEncoding': None,
        'deprecation_showImageFormat': None, 'deprecation_showPyplotGlobalUse': None,
        'theme_backgroundColor': None, 'theme_font': None,
    }

    args = []
    if run_mode == "lite":
        args += [
            "--",
            "lite",
        ]

    try:
        # for streamlit >= 1.12.1
        from streamlit.web import bootstrap
    except ImportError:
        from streamlit import bootstrap

    # bootstrap.run() does not return until the Streamlit server stops, so the
    # start-up event must be signalled before the call; setting it afterwards
    # would leave the waiting parent blocked for the whole lifetime of the UI.
    # The None guard matches the parameter's default.
    if started_event is not None:
        started_event.set()
    bootstrap.run(script_path, False, args, flag_options)
def parse_args(argv=None) -> "tuple[argparse.Namespace, argparse.ArgumentParser]":
    """Parse the launcher's command-line flags.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        A ``(args, parser)`` pair: the parsed namespace and the parser that
        produced it, so the caller can print the help text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a",
        "--all-webui",
        action="store_true",
        help="run model_providers servers,run api.py and webui.py",
        dest="all_webui",
    )
    parser.add_argument(
        "--all-api",
        action="store_true",
        help="run model_providers servers, run api.py",
        dest="all_api",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
        dest="api",
    )
    parser.add_argument(
        "-w",
        "--webui",
        action="store_true",
        help="run webui.py server",
        dest="webui",
    )
    parser.add_argument(
        "-i",
        "--lite",
        action="store_true",
        help="以Lite模式运行仅支持在线API的LLM对话、搜索引擎对话",
        dest="lite",
    )
    args = parser.parse_args(argv)
    return args, parser
def dump_server_info(after_start=False, args=None):
    """Print a configuration banner and, optionally, the running server addresses.

    Args:
        after_start: When true, also print the addresses of the servers
            selected in ``args``.
        args: Parsed CLI namespace; its ``api`` and ``webui`` attributes are
            read when ``after_start`` is true. ``None`` or missing attributes
            are treated as "not selected".
    """
    import platform
    import langchain
    from chatchat.server.utils import api_address, webui_address
    # The MODEL_PROVIDERS_* names are imported here because the report below
    # uses them; without the import the after-start branch raises NameError.
    from chatchat.configs import (VERSION, TEXT_SPLITTER_NAME, DEFAULT_EMBEDDING_MODEL,
                                  MODEL_PROVIDERS_CFG_PATH_CONFIG,
                                  MODEL_PROVIDERS_CFG_HOST,
                                  MODEL_PROVIDERS_CFG_PORT)

    banner = "=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30

    print("\n")
    print(banner)
    print(f"操作系统:{platform.platform()}.")
    print(f"python版本{sys.version}")
    print(f"项目版本:{VERSION}")
    print(f"langchain版本{langchain.__version__}")
    print("\n")

    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
    print(f"默认选用的 Embedding 名称: {DEFAULT_EMBEDDING_MODEL}")

    if after_start:
        print("\n")
        print("服务端运行信息:")
        if getattr(args, "api", False):
            # The last line reports the port; the original repeated the host.
            print(
                f" Chatchat Model providers Server: model_providers_cfg_path_config:{MODEL_PROVIDERS_CFG_PATH_CONFIG}\n"
                f" provider_host:{MODEL_PROVIDERS_CFG_HOST}\n"
                f" provider_port:{MODEL_PROVIDERS_CFG_PORT}\n")
            print(f" Chatchat Api Server: {api_address()}")
        if getattr(args, "webui", False):
            print(f" Chatchat WEBUI Server: {webui_address()}")
    print(banner)
    print("\n")
async def start_main_server():
    """Parse the CLI flags, start the selected server processes and supervise them.

    Depending on the flags, up to three child processes are created, in this
    order: the model-providers server, the API server and the web UI. Each is
    started only after the previous one has signalled its start-up event.
    The function then waits until every child has exited. On interruption or
    error, children that are still alive are killed.

    With no server selected the argparse help text is printed instead.
    """
    import signal
    from chatchat.configs import LOG_PATH

    def handler(signalname):
        """
        Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
        Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
        """
        def f(signal_received, frame):
            raise KeyboardInterrupt(f"{signalname} received")
        return f

    # This will be inherited by the child process if it is forked (not spawned)
    signal.signal(signal.SIGINT, handler("SIGINT"))
    signal.signal(signal.SIGTERM, handler("SIGTERM"))

    mp.set_start_method("spawn")
    manager = mp.Manager()
    run_mode = None

    args, parser = parse_args()

    # Expand the convenience flags into the individual api / webui switches.
    if args.all_webui:
        args.api = True
        args.api_worker = True
        args.webui = True
    elif args.all_api:
        args.api = True
        args.api_worker = True
        args.webui = False
    elif args.api:
        args.api = True
        args.api_worker = False
        args.webui = False

    if args.lite:
        # NOTE(review): run_mode stays None even with --lite, so the
        # `run_mode == "lite"` branch in run_webui never fires -- confirm
        # whether run_mode should be set to "lite" here.
        args.api = True
        args.api_worker = False
        args.webui = True

    dump_server_info(args=args)

    if len(sys.argv) > 1:
        logger.info(f"正在启动服务:")
        logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")

    processes = {}

    # Shared state: a Manager-backed dict passed to every child, plus one
    # start-up event per child.
    model_platforms_shard = manager.dict()
    model_providers_started = manager.Event()
    if args.api:
        process = Process(
            target=run_init_server,
            name=f"Model providers Server",
            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started,
                        run_mode=run_mode),
            daemon=True,
        )
        processes["model_providers"] = process

    api_started = manager.Event()
    if args.api:
        process = Process(
            target=run_api_server,
            name=f"API Server",
            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
            daemon=False,
        )
        processes["api"] = process

    webui_started = manager.Event()
    if args.webui:
        process = Process(
            target=run_webui,
            name=f"WEBUI Server",
            kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=webui_started, run_mode=run_mode),
            daemon=True,
        )
        processes["webui"] = process

    if not processes:
        parser.print_help()
        return

    try:
        # Make sure the tasks can exit cleanly after receiving SIGINT.
        if p := processes.get("model_providers"):
            p.start()
            p.name = f"{p.name} ({p.pid})"
            model_providers_started.wait()  # wait for model_providers to finish starting

        if p := processes.get("api"):
            p.start()
            p.name = f"{p.name} ({p.pid})"
            api_started.wait()  # wait for api.py to finish starting

        if p := processes.get("webui"):
            p.start()
            p.name = f"{p.name} ({p.pid})"
            webui_started.wait()  # wait for webui.py to finish starting

        dump_server_info(after_start=True, args=args)

        # Wait for all processes to exit. Iterate over a snapshot and remove
        # entries by their dict key: the dict shrinks during the loop, and
        # p.name was rewritten above so it no longer matches any key.
        while processes:
            for key, p in list(processes.items()):
                p.join(2)
                if not p.is_alive():
                    processes.pop(key)
    except KeyboardInterrupt:
        # Raised by the SIGINT/SIGTERM handlers installed above. It is not an
        # Exception subclass, so it needs its own clause.
        logger.warning("Caught KeyboardInterrupt! Setting stop event...")
    except Exception as e:
        logger.error(e)
    finally:
        for p in processes.values():
            # Skip children that were never started or have already exited;
            # kill() on a never-started Process raises AttributeError.
            if not p.is_alive():
                continue
            logger.warning("Sending SIGKILL to %s", p)
            # Queues and other inter-process communication primitives can break when
            # process is killed, but we don't care here
            p.kill()

        for p in processes.values():
            logger.info("Process status: %s", p)
def main():
    """Entry point: prepare the environment, create the tables, run the launcher.

    Appends the current working directory to ``sys.path``, calls
    ``multiprocessing.freeze_support()``, runs ``create_tables()`` and then
    drives ``start_main_server()`` to completion on an asyncio event loop.
    """
    # Make modules located in the current working directory importable.
    cwd = os.getcwd()
    sys.path.append(cwd)
    multiprocessing.freeze_support()
    print(f"cwd:{cwd}")

    from chatchat.server.knowledge_base.migrate import create_tables
    create_tables()

    if sys.version_info >= (3, 10):
        # Reuse a running loop if there is one, otherwise create a new loop,
        # then register the result as the current event loop.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    else:
        loop = asyncio.get_event_loop()

    loop.run_until_complete(start_main_server())
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()