diff --git a/server/api_allinone.py b/server/api_allinone.py
index 4379d7d2..3be8581e 100644
--- a/server/api_allinone.py
+++ b/server/api_allinone.py
@@ -11,10 +11,11 @@ python server/api_allinone.py --model-path-address model@host@port --num-gpus 2
 """
 import sys
 import os
+
 sys.path.append(os.path.dirname(__file__))
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from llm_api_launch import launch_all,parser,controller_args,worker_args,server_args
+from llm_api_launch import launch_all, parser, controller_args, worker_args, server_args
 from api import create_app
 import uvicorn
 
@@ -23,8 +24,8 @@
 parser.add_argument("--api-port", type=int, default=7861)
 parser.add_argument("--ssl_keyfile", type=str)
 parser.add_argument("--ssl_certfile", type=str)
+api_args = ["api-host", "api-port", "ssl_keyfile", "ssl_certfile"]
 
-api_args = ["api-host","api-port","ssl_keyfile","ssl_certfile"]
 
 def run_api(host, port, **kwargs):
     app = create_app()
@@ -38,18 +39,19 @@ def run_api(host, port, **kwargs):
     else:
         uvicorn.run(app, host=host, port=port)
 
+
 if __name__ == "__main__":
     print("Launching api_allinone, it would take a while, please be patient...")
     print("正在启动api_allinone,LLM服务启动约3-10分钟,请耐心等待...")
     # initialization message
     args = parser.parse_args()
     args_dict = vars(args)
-    launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args)
+    launch_all(args=args, controller_args=controller_args, worker_args=worker_args, server_args=server_args)
     run_api(
-            host=args.api_host,
-            port=args.api_port,
-            ssl_keyfile=args.ssl_keyfile,
-            ssl_certfile=args.ssl_certfile,
-            )
+        host=args.api_host,
+        port=args.api_port,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+    )
     print("Launching api_allinone done.")
     print("api_allinone启动完毕.")
diff --git a/server/llm_api_launch.py b/server/llm_api_launch.py
index 044bab65..0f7710a9 100644
--- a/server/llm_api_launch.py
+++ b/server/llm_api_launch.py
@@ -132,7 +132,7 @@ worker_args = [
     "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
     "gptq-act-order", "model-names", "limit-worker-concurrency",
     "stream-interval", "no-register",
-    "controller-address","worker-address"
+    "controller-address", "worker-address"
 ]
 
 # -----------------openai server---------------------------
@@ -159,8 +159,6 @@ server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
                "controller-address"
                ]
 
-
-
 # 0, controller, model_worker, openai_api_server
 # 1, command-line options
 # 2, LOG_PATH
@@ -201,7 +199,7 @@ def string_args(args, args_list):
     return args_str
 
 
-def launch_worker(item,args,worker_args=worker_args):
+def launch_worker(item, args, worker_args=worker_args):
     log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
     # split model-path-address first, then pass the parts to string_args to build the argument string
     args.model_path, args.worker_host, args.worker_port = item.split("@")
@@ -230,11 +228,11 @@ def launch_all(args,
     subprocess.run(controller_check_sh, shell=True, check=True)
     print(f"worker启动时间视设备不同而不同,约需3-10分钟,请耐心等待...")
     if isinstance(args.model_path_address, str):
-        launch_worker(args.model_path_address,args=args,worker_args=worker_args)
+        launch_worker(args.model_path_address, args=args, worker_args=worker_args)
     else:
         for idx, item in enumerate(args.model_path_address):
             print(f"开始加载第{idx}个模型:{item}")
-            launch_worker(item,args=args,worker_args=worker_args)
+            launch_worker(item, args=args, worker_args=worker_args)
     server_str_args = string_args(args, server_args)
     server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH,
                                       "openai_api_server")
@@ -244,11 +242,12 @@ def launch_all(args,
     print("Launching LLM service done!")
     print("LLM服务启动完毕。")
 
+
 if __name__ == "__main__":
     args = parser.parse_args()
     # the http:// prefix is required; without it requests raises InvalidSchema: No connection adapters were found
     args = argparse.Namespace(**vars(args),
-                **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
+                              **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
     if args.gpus:
         if len(args.gpus.split(",")) < args.num_gpus: