diff --git a/.gitignore b/.gitignore
index a411f1a0..124c04e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@
 *.log.*
 logs
 .idea/
-__pycache__/
\ No newline at end of file
+__pycache__/
+knowledge_base/
diff --git a/api_one.py b/api_one.py
new file mode 100644
index 00000000..542a8f0a
--- /dev/null
+++ b/api_one.py
@@ -0,0 +1,157 @@
+import sys
+import os
+from configs.model_config import (
+    llm_model_dict,
+    LLM_MODEL,
+    EMBEDDING_DEVICE,
+    LLM_DEVICE,
+    LOG_PATH,
+    logger,
+)
+from fastchat.serve.model_worker import heart_beat_worker
+from fastapi import FastAPI
+import threading
+import asyncio
+
+# Merge fastchat's three servers into a single app; the mounts are:
+# - http://{host_ip}:{port}/controller corresponds to python -m fastchat.serve.controller
+# - http://{host_ip}:{port}/model_worker corresponds to python -m fastchat.serve.model_worker ...
+# - http://{host_ip}:{port}/openai corresponds to python -m fastchat.serve.openai_api_server ...
+
+
+host_ip = "0.0.0.0"
+port = 8888
+base_url = f"http://127.0.0.1:{port}"
+
+
+def create_controller_app(
+    dispatch_method="shortest_queue",
+):
+    from fastchat.serve.controller import app, Controller
+    controller = Controller(dispatch_method)
+    sys.modules["fastchat.serve.controller"].controller = controller
+    logger.info(f"controller dispatch method: {dispatch_method}")
+    return app, controller
+
+
+def create_model_worker_app(
+    model_path,
+    model_names=[LLM_MODEL],
+    device=LLM_DEVICE,
+    load_8bit=False,
+    gptq_ckpt=None,
+    gptq_wbits=16,
+    gpus='',
+    num_gpus=-1,
+    max_gpu_memory=-1,
+    cpu_offloading=None,
+    worker_address=f"{base_url}/model_worker",
+    controller_address=f"{base_url}/controller",
+    limit_model_concurrency=5,
+    stream_interval=2,
+    no_register=True,  # registered manually in on_startup instead
+):
+    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
+    from fastchat.serve import model_worker
+    if gpus and num_gpus is None:
+        num_gpus = len(gpus.split(','))
+    gptq_config = GptqConfig(
+        ckpt=gptq_ckpt or model_path,
+        wbits=gptq_wbits,
+        groupsize=-1,
+        act_order=None,
+    )
+    worker = ModelWorker(
+        controller_address,
+        worker_address,
+        worker_id,
+        no_register,
+        model_path,
+        model_names,
+        device,
+        num_gpus,
+        max_gpu_memory,
+        load_8bit,
+        cpu_offloading,
+        gptq_config,
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    return app, worker
+
+
+def create_openai_api_app(
+    host=host_ip,
+    port=port,
+    controller_address=f"{base_url}/controller",
+    allow_credentials=None,
+    allowed_origins=["*"],
+    allowed_methods=["*"],
+    allowed_headers=["*"],
+    api_keys=[],
+):
+    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=allowed_origins,
+        allow_credentials=allow_credentials,
+        allow_methods=allowed_methods,
+        allow_headers=allowed_headers,
+    )
+    app_settings.controller_address = controller_address
+    app_settings.api_keys = api_keys
+    sys.modules["fastchat.serve.openai_api_server"].app_settings = app_settings
+
+    return app
+
+
+LLM_MODEL = 'chatglm-6b'
+model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
+global controller
+
+
+if not model_path:
+    logger.error("local_model_path must not be empty")
+else:
+    logger.info(f"using local model: {model_path}")
+    app = FastAPI()
+
+    controller_app, controller = create_controller_app()
+    app.mount("/controller", controller_app)
+
+    model_worker_app, worker = create_model_worker_app(model_path)
+    app.mount("/model_worker", model_worker_app)
+
+    openai_api_app = create_openai_api_app()
+    app.mount("/openai", openai_api_app)
+
+
+    @app.on_event("startup")
+    async def on_startup():
+        logger.info("Register to controller")
+        controller.register_worker(
+            worker.worker_addr,
+            True,
+            worker.get_status(),
+        )
+        worker.heart_beat_thread = threading.Thread(
+            target=heart_beat_worker, args=(worker,)
+        )
+        worker.heart_beat_thread.start()
+
+        # Registering with the controller over HTTP deadlocks here, so it is done in-process above:
+        # async def register():
+        #     while True:
+        #         try:
+        #             worker.register_to_controller()
+        #             worker.heart_beat_thread = threading.Thread(
+        #                 target=heart_beat_worker, args=(worker,)
+        #             )
+        #             worker.heart_beat_thread.start()
+        #         except:
+        #             await asyncio.sleep(1)
+        # asyncio.get_event_loop().create_task(register())
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host=host_ip, port=port)
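
To verify that the in-process registration worked, a minimal smoke-test sketch (not part of this patch) can poll the mounted controller. It assumes the merged app from `api_one.py` is running with the defaults above (port 8888), that the fastchat controller exposes `POST /list_models`, and that `requests` is installed; the script name is hypothetical.

```python
# smoke_test_controller.py — hypothetical check, not part of this patch.
# Assumes api_one.py is running locally on port 8888.
import requests

BASE = "http://127.0.0.1:8888"

# The controller is mounted under /controller; fastchat's controller
# exposes POST /list_models, which should include the in-process
# worker's models once on_startup has registered it.
resp = requests.post(f"{BASE}/controller/list_models")
resp.raise_for_status()
print(resp.json())  # expected to contain "chatglm-6b"
```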
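
Because the OpenAI-compatible server is mounted under `/openai`, its standard routes gain that prefix, e.g. `/v1/chat/completions` becomes `/openai/v1/chat/completions`. A hedged example call, again assuming the default `host_ip`/`port` from the diff and a loaded `chatglm-6b`:

```python
# smoke_test_chat.py — hypothetical, not part of this patch.
import requests

BASE = "http://127.0.0.1:8888"

# Standard OpenAI-style chat payload, routed through the /openai mount.
payload = {
    "model": "chatglm-6b",
    "messages": [{"role": "user", "content": "Hello"}],
}
resp = requests.post(f"{BASE}/openai/v1/chat/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```

Mounting keeps all three services on one port, so a client only needs the prefix to pick the service it talks to.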