diff --git a/libs/chatchat-server/chatchat/init_database.py b/libs/chatchat-server/chatchat/init_database.py index 67c28963..7a1baec7 100644 --- a/libs/chatchat-server/chatchat/init_database.py +++ b/libs/chatchat-server/chatchat/init_database.py @@ -1,14 +1,42 @@ -import sys -sys.path.append("chatchat") +# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 +from typing import Dict from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from chatchat.configs import DEFAULT_EMBEDDING_MODEL +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS +import multiprocessing as mp +import logging +logger = logging.getLogger(__name__) + from datetime import datetime +def run_init_model_provider( + model_platforms_shard: Dict, + started_event: mp.Event = None, + model_providers_cfg_path: str = None, + provider_host: str = None, + provider_port: int = None): + from chatchat.init_server import init_server + from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG, + MODEL_PROVIDERS_CFG_HOST, + MODEL_PROVIDERS_CFG_PORT) + if model_providers_cfg_path is None: + model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG + if provider_host is None: + provider_host = MODEL_PROVIDERS_CFG_HOST + if provider_port is None: + provider_port = MODEL_PROVIDERS_CFG_PORT + + init_server(model_platforms_shard=model_platforms_shard, + started_event=started_event, + model_providers_cfg_path=model_providers_cfg_path, + provider_host=provider_host, + provider_port=provider_port) + + if __name__ == "__main__": import argparse - + parser = argparse.ArgumentParser(description="please specify only one operate method once time.") parser.add_argument( @@ -92,27 +120,69 @@ if __name__ == "__main__": args = parser.parse_args() start_time = datetime.now() - if args.create_tables: - create_tables() # confirm tables exist + mp.set_start_method("spawn") + manager = mp.Manager() - if args.clear_tables: - reset_tables() - print("database tables reset") + # 定义全局配置变量,使用 Manager 创建共享字典 + model_platforms_shard = manager.dict() + model_providers_started = manager.Event() + processes = {} + process = mp.Process( + target=run_init_model_provider, + name=f"Model providers Server", + kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started), + daemon=True, + ) + processes["model_providers"] = process + try: + # 保证任务收到SIGINT后,能够正常退出 + if p := processes.get("model_providers"): + p.start() + p.name = f"{p.name} ({p.pid})" + model_providers_started.wait() # 等待model_providers启动完成 + MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms']) + logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}") - if args.recreate_vs: - create_tables() - print("recreating all vector stores") - folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) - elif args.import_db: - import_from_db(args.import_db) - elif args.update_in_db: - folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) - elif args.increment: - folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model) - elif args.prune_db: - prune_db_docs(args.kb_name) - elif args.prune_folder: - prune_folder_files(args.kb_name) - end_time = datetime.now() - print(f"总计用时: {end_time-start_time}") + if args.create_tables: + create_tables() # confirm tables exist + + if args.clear_tables: + reset_tables() + print("database tables reset") + + if args.recreate_vs: + create_tables() + print("recreating all vector stores") + folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) + elif args.import_db: + import_from_db(args.import_db) + elif args.update_in_db: + folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) + elif args.increment: + folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model) + elif args.prune_db: + prune_db_docs(args.kb_name) + elif args.prune_folder: + prune_folder_files(args.kb_name) + + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") + except Exception as e: + logger.error(e) + logger.warning("Caught KeyboardInterrupt! Setting stop event...") + finally: + + for p in processes.values(): + logger.warning("Sending SIGKILL to %s", p) + # Queues and other inter-process communication primitives can break when + # process is killed, but we don't care here + + if isinstance(p, dict): + for process in p.values(): + process.kill() + else: + p.kill() + + for p in processes.values(): + logger.info("Process status: %s", p) diff --git a/libs/chatchat-server/chatchat/startup.py b/libs/chatchat-server/chatchat/startup.py index 01a7b10c..992864b3 100644 --- a/libs/chatchat-server/chatchat/startup.py +++ b/libs/chatchat-server/chatchat/startup.py @@ -284,7 +284,7 @@ async def start_main_server(): target=run_api_server, name=f"API Server", kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode), - daemon=True, + daemon=False, ) processes["api"] = process @@ -367,4 +367,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()