Merge pull request #4122 from chatchat-space/dev_init_database_providers

Dev init database providers关闭守护进程
This commit is contained in:
glide-the 2024-06-02 17:17:42 +08:00 committed by GitHub
commit bc6832bc7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 27 deletions

View File

@ -1,11 +1,39 @@
import sys # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
sys.path.append("chatchat") from typing import Dict
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files) folder2db, prune_db_docs, prune_folder_files)
from chatchat.configs import DEFAULT_EMBEDDING_MODEL from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
import multiprocessing as mp
import logging
logger = logging.getLogger(__name__)
from datetime import datetime from datetime import datetime
def run_init_model_provider(
model_platforms_shard: Dict,
started_event: mp.Event = None,
model_providers_cfg_path: str = None,
provider_host: str = None,
provider_port: int = None):
from chatchat.init_server import init_server
from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG,
MODEL_PROVIDERS_CFG_HOST,
MODEL_PROVIDERS_CFG_PORT)
if model_providers_cfg_path is None:
model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG
if provider_host is None:
provider_host = MODEL_PROVIDERS_CFG_HOST
if provider_port is None:
provider_port = MODEL_PROVIDERS_CFG_PORT
init_server(model_platforms_shard=model_platforms_shard,
started_event=started_event,
model_providers_cfg_path=model_providers_cfg_path,
provider_host=provider_host,
provider_port=provider_port)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@ -92,6 +120,30 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
start_time = datetime.now() start_time = datetime.now()
mp.set_start_method("spawn")
manager = mp.Manager()
# 定义全局配置变量,使用 Manager 创建共享字典
model_platforms_shard = manager.dict()
model_providers_started = manager.Event()
processes = {}
process = mp.Process(
target=run_init_model_provider,
name=f"Model providers Server",
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started),
daemon=True,
)
processes["model_providers"] = process
try:
# 保证任务收到SIGINT后能够正常退出
if p := processes.get("model_providers"):
p.start()
p.name = f"{p.name} ({p.pid})"
model_providers_started.wait() # 等待model_providers启动完成
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
if args.create_tables: if args.create_tables:
create_tables() # confirm tables exist create_tables() # confirm tables exist
@ -116,3 +168,21 @@ if __name__ == "__main__":
end_time = datetime.now() end_time = datetime.now()
print(f"总计用时: {end_time-start_time}") print(f"总计用时: {end_time-start_time}")
except Exception as e:
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
for p in processes.values():
logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here
if isinstance(p, dict):
for process in p.values():
process.kill()
else:
p.kill()
for p in processes.values():
logger.info("Process status: %s", p)

View File

@ -284,7 +284,7 @@ async def start_main_server():
target=run_api_server, target=run_api_server,
name=f"API Server", name=f"API Server",
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode), kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
daemon=True, daemon=False,
) )
processes["api"] = process processes["api"] = process