mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
Merge pull request #4122 from chatchat-space/dev_init_database_providers
Dev init database providers关闭守护进程
This commit is contained in:
commit
bc6832bc7f
@ -1,14 +1,42 @@
|
||||
import sys
|
||||
sys.path.append("chatchat")
|
||||
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
||||
from typing import Dict
|
||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||
folder2db, prune_db_docs, prune_folder_files)
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def run_init_model_provider(
|
||||
model_platforms_shard: Dict,
|
||||
started_event: mp.Event = None,
|
||||
model_providers_cfg_path: str = None,
|
||||
provider_host: str = None,
|
||||
provider_port: int = None):
|
||||
from chatchat.init_server import init_server
|
||||
from chatchat.configs import (MODEL_PROVIDERS_CFG_PATH_CONFIG,
|
||||
MODEL_PROVIDERS_CFG_HOST,
|
||||
MODEL_PROVIDERS_CFG_PORT)
|
||||
if model_providers_cfg_path is None:
|
||||
model_providers_cfg_path = MODEL_PROVIDERS_CFG_PATH_CONFIG
|
||||
if provider_host is None:
|
||||
provider_host = MODEL_PROVIDERS_CFG_HOST
|
||||
if provider_port is None:
|
||||
provider_port = MODEL_PROVIDERS_CFG_PORT
|
||||
|
||||
init_server(model_platforms_shard=model_platforms_shard,
|
||||
started_event=started_event,
|
||||
model_providers_cfg_path=model_providers_cfg_path,
|
||||
provider_host=provider_host,
|
||||
provider_port=provider_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
||||
|
||||
parser.add_argument(
|
||||
@ -92,27 +120,69 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
start_time = datetime.now()
|
||||
|
||||
if args.create_tables:
|
||||
create_tables() # confirm tables exist
|
||||
mp.set_start_method("spawn")
|
||||
manager = mp.Manager()
|
||||
|
||||
if args.clear_tables:
|
||||
reset_tables()
|
||||
print("database tables reset")
|
||||
# 定义全局配置变量,使用 Manager 创建共享字典
|
||||
model_platforms_shard = manager.dict()
|
||||
model_providers_started = manager.Event()
|
||||
processes = {}
|
||||
process = mp.Process(
|
||||
target=run_init_model_provider,
|
||||
name=f"Model providers Server",
|
||||
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=model_providers_started),
|
||||
daemon=True,
|
||||
)
|
||||
processes["model_providers"] = process
|
||||
try:
|
||||
# 保证任务收到SIGINT后,能够正常退出
|
||||
if p := processes.get("model_providers"):
|
||||
p.start()
|
||||
p.name = f"{p.name} ({p.pid})"
|
||||
model_providers_started.wait() # 等待model_providers启动完成
|
||||
MODEL_PLATFORMS.extend(model_platforms_shard['provider_platforms'])
|
||||
logger.info(f"Api MODEL_PLATFORMS: {MODEL_PLATFORMS}")
|
||||
|
||||
if args.recreate_vs:
|
||||
create_tables()
|
||||
print("recreating all vector stores")
|
||||
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
|
||||
elif args.import_db:
|
||||
import_from_db(args.import_db)
|
||||
elif args.update_in_db:
|
||||
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
|
||||
elif args.increment:
|
||||
folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
|
||||
elif args.prune_db:
|
||||
prune_db_docs(args.kb_name)
|
||||
elif args.prune_folder:
|
||||
prune_folder_files(args.kb_name)
|
||||
|
||||
end_time = datetime.now()
|
||||
print(f"总计用时: {end_time-start_time}")
|
||||
if args.create_tables:
|
||||
create_tables() # confirm tables exist
|
||||
|
||||
if args.clear_tables:
|
||||
reset_tables()
|
||||
print("database tables reset")
|
||||
|
||||
if args.recreate_vs:
|
||||
create_tables()
|
||||
print("recreating all vector stores")
|
||||
folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
|
||||
elif args.import_db:
|
||||
import_from_db(args.import_db)
|
||||
elif args.update_in_db:
|
||||
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
|
||||
elif args.increment:
|
||||
folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
|
||||
elif args.prune_db:
|
||||
prune_db_docs(args.kb_name)
|
||||
elif args.prune_folder:
|
||||
prune_folder_files(args.kb_name)
|
||||
|
||||
end_time = datetime.now()
|
||||
print(f"总计用时: {end_time-start_time}")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
|
||||
finally:
|
||||
|
||||
for p in processes.values():
|
||||
logger.warning("Sending SIGKILL to %s", p)
|
||||
# Queues and other inter-process communication primitives can break when
|
||||
# process is killed, but we don't care here
|
||||
|
||||
if isinstance(p, dict):
|
||||
for process in p.values():
|
||||
process.kill()
|
||||
else:
|
||||
p.kill()
|
||||
|
||||
for p in processes.values():
|
||||
logger.info("Process status: %s", p)
|
||||
|
||||
@ -284,7 +284,7 @@ async def start_main_server():
|
||||
target=run_api_server,
|
||||
name=f"API Server",
|
||||
kwargs=dict(model_platforms_shard=model_platforms_shard, started_event=api_started, run_mode=run_mode),
|
||||
daemon=True,
|
||||
daemon=False,
|
||||
)
|
||||
processes["api"] = process
|
||||
|
||||
@ -367,4 +367,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user