Langchain-Chatchat/startup.py
liunux4odoo d0846f88cc - pydantic 限定为 v1,并统一项目中所有 pydantic 导入路径,为以后升级 v2 做准备
- 重构 api.py:
    - 按模块划分为不同的 router
    - 添加 openai 兼容的转发接口,项目默认使用该接口以实现模型负载均衡
    - 添加 /tools 接口,可以获取/调用编写的 agent tools
    - 移除所有 EmbeddingFuncAdapter,统一改用 get_Embeddings
    - 待办:
        - /chat/chat 接口改为 openai 兼容
        - 添加 /chat/kb_chat 接口,openai 兼容
        - 改变 nltk/knowledge_base/logs 等数据目录位置
2024-03-06 13:51:34 +08:00

302 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from contextlib import asynccontextmanager
import multiprocessing as mp
import os
import subprocess
import sys
from multiprocessing import Process
# Set numexpr's maximum thread count to the number of detected CPU cores.
# This is best-effort: if numexpr cannot be imported or the core count cannot
# be detected, the environment is left untouched.
try:
    import numexpr

    n_cores = numexpr.utils.detect_number_of_cores()
    os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except Exception:
    # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit raised
    # while importing are no longer silently swallowed.
    pass
from configs import (
LOG_PATH,
log_verbose,
logger,
DEFAULT_EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
API_SERVER,
WEBUI_SERVER,
)
from server.utils import FastAPI
from server.knowledge_base.migrate import create_tables
import argparse
from typing import Dict, List, Tuple
from configs import VERSION
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
    """Arrange for ``started_event`` to be set once ``app`` has started up.

    The app's current lifespan context is wrapped rather than replaced, so any
    startup/shutdown logic already attached to the app keeps running. The
    event is set after that wrapped startup logic has completed.

    Args:
        app: The FastAPI application to hook.
        started_event: Event to set on startup; ``None`` disables signalling.
    """
    # Capture the existing lifespan so it can be composed. Assigning a bare
    # lifespan would discard any startup/shutdown handling configured earlier
    # (e.g. inside ``create_app``).
    previous_lifespan = app.router.lifespan_context

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        async with previous_lifespan(app) as state:
            if started_event is not None:
                started_event.set()
            # Pass through whatever state the wrapped lifespan yielded.
            yield state

    app.router.lifespan_context = lifespan
def run_api_server(started_event: mp.Event = None, run_mode: str = None):
    """Create the Chatchat API app and serve it with uvicorn (blocking).

    Used as a ``multiprocessing.Process`` target by ``start_main_server``.

    Args:
        started_event: Handed to ``_set_app_event`` so the parent is notified
            once the app starts; ``None`` disables signalling.
        run_mode: Forwarded unchanged to ``create_app`` (e.g. ``"lite"``).
    """
    # Server-side imports are deferred until the function actually runs.
    from server.api_server.server_app import create_app
    import uvicorn
    from server.utils import set_httpx_config

    set_httpx_config()

    api_app = create_app(run_mode=run_mode)
    _set_app_event(api_app, started_event)

    bind = {"host": API_SERVER["host"], "port": API_SERVER["port"]}
    uvicorn.run(api_app, **bind)
def run_webui(started_event: mp.Event = None, run_mode: str = None):
    """Launch the Streamlit web UI as a subprocess and wait for it to exit.

    Args:
        started_event: Set right after the Streamlit process is launched (not
            when it is ready to serve). ``None`` disables signalling.
        run_mode: When ``"lite"``, the extra script argument ``lite`` is passed
            to ``webui.py``.
    """
    from server.utils import set_httpx_config
    set_httpx_config()

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]

    # Run streamlit through the current interpreter so the UI uses the same
    # environment as this process, rather than whichever ``streamlit``
    # executable (if any) happens to be first on PATH.
    cmd = [sys.executable, "-m", "streamlit", "run", "webui.py",
           "--server.address", host,
           "--server.port", str(port),
           "--theme.base", "light",
           "--theme.primaryColor", "#165dff",
           "--theme.secondaryBackgroundColor", "#f5f5f5",
           "--theme.textColor", "#000000",
           ]
    if run_mode == "lite":
        # Streamlit forwards arguments after "--" to the script itself.
        cmd += [
            "--",
            "lite",
        ]

    p = subprocess.Popen(cmd)
    try:
        # The default is None; guard instead of crashing with AttributeError.
        if started_event is not None:
            started_event.set()
        p.wait()
    finally:
        # If waiting is interrupted (e.g. KeyboardInterrupt), don't leave the
        # streamlit process running on its own.
        if p.poll() is None:
            p.terminate()
def run_loom(started_event: mp.Event = None):
    """Launch the loom_core local OpenAI-plugin deployment and wait for it.

    Args:
        started_event: Set right after the subprocess is launched (not when it
            is ready). ``None`` disables signalling.
    """
    from configs import LOOM_CONFIG

    # Use the current interpreter: a bare "python" may be absent from PATH
    # (e.g. systems that only ship "python3") or may belong to a different
    # environment that lacks loom_core.
    cmd = [sys.executable, "-m", "loom_core.openai_plugins.deploy.local",
           "-f", LOOM_CONFIG
           ]
    p = subprocess.Popen(cmd)
    try:
        # The default is None; guard instead of crashing with AttributeError.
        if started_event is not None:
            started_event.set()
        p.wait()
    finally:
        # Don't orphan the subprocess if waiting is interrupted.
        if p.poll() is None:
            p.terminate()
def parse_args() -> Tuple[argparse.Namespace, argparse.ArgumentParser]:
    """Parse the startup command line.

    Returns:
        An ``(args, parser)`` pair: the parsed namespace and the parser itself.
        The parser is returned so the caller can print help when no server is
        selected. (The previous annotation wrongly claimed a bare
        ``ArgumentParser`` was returned.)
    """
    parser = argparse.ArgumentParser()
    # Help texts describe what start_main_server actually launches; fastchat
    # controller/model_worker servers are no longer started from here.
    parser.add_argument(
        "-a",
        "--all-webui",
        action="store_true",
        help="run api.py and webui.py",
        dest="all_webui",
    )
    parser.add_argument(
        "--all-api",
        action="store_true",
        help="run api.py",
        dest="all_api",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
        dest="api",
    )
    parser.add_argument(
        "-w",
        "--webui",
        action="store_true",
        help="run webui.py server",
        dest="webui",
    )
    parser.add_argument(
        "-i",
        "--lite",
        action="store_true",
        help="以Lite模式运行仅支持在线API的LLM对话、搜索引擎对话",
        dest="lite",
    )
    args = parser.parse_args()
    return args, parser
def dump_server_info(after_start=False, args=None):
    """Print environment/configuration info and, after startup, server addresses.

    Args:
        after_start: When true, also print the addresses of the servers that
            ``args`` selected.
        args: Namespace from ``parse_args``; only read when ``after_start`` is
            true. If it is ``None`` the address section is skipped instead of
            crashing.
    """
    import platform
    import langchain
    from server.utils import api_address, webui_address

    print("\n")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
    print(f"操作系统:{platform.platform()}.")
    print(f"python版本{sys.version}")
    print(f"项目版本:{VERSION}")
    print(f"langchain版本{langchain.__version__}")
    print("\n")
    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
    print(f"当前Embbedings模型 {DEFAULT_EMBEDDING_MODEL}")

    if after_start and args is not None:
        print("\n")
        print(f"服务端运行信息:")
        # api_address was imported but never used; report the API server too.
        if args.api:
            print(f" Chatchat API Server: {api_address()}")
        if args.webui:
            print(f" Chatchat WEBUI Server: {webui_address()}")
    print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
    print("\n")
async def start_main_server():
    """Parse CLI arguments, launch the selected servers and supervise them.

    Each selected server runs in its own ``multiprocessing.Process`` (spawn
    start method). Servers are started one after another. Each start waits
    for that server's ``started`` event, and fails fast if the process dies
    first. The coroutine then blocks until every started process exits. On the
    way out (normal exit, error, or SIGINT/SIGTERM, which are turned into
    ``KeyboardInterrupt``), every process still alive is killed.
    """
    import signal

    def handler(signalname):
        """Build a signal handler that raises ``KeyboardInterrupt``.

        Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
        Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
        """
        def f(signal_received, frame):
            raise KeyboardInterrupt(f"{signalname} received")

        return f

    # This will be inherited by the child process if it is forked (not spawned)
    signal.signal(signal.SIGINT, handler("SIGINT"))
    signal.signal(signal.SIGTERM, handler("SIGTERM"))

    mp.set_start_method("spawn")
    manager = mp.Manager()
    run_mode = None

    args, parser = parse_args()

    if args.all_webui:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.api_worker = True
        args.webui = True
    elif args.all_api:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.api_worker = True
        args.webui = False

    if args.lite:
        args.model_worker = False
        run_mode = "lite"

    dump_server_info(args=args)

    if len(sys.argv) > 1:
        logger.info(f"正在启动服务:")
        logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")

    # Top-level values are Process objects, except "online_api" and
    # "model_worker", which are dicts grouping Process objects.
    processes = {"online_api": {}, "model_worker": {}}

    def iter_processes():
        """Yield every Process in ``processes``, flattening the grouped dicts."""
        for entry in processes.values():
            if isinstance(entry, dict):
                yield from entry.values()
            else:
                yield entry

    def process_count():
        # Count real processes only. ``len(processes)`` also counted the two
        # group dicts, so it was never 0 and the help text was never printed.
        return sum(1 for _ in iter_processes())

    def wait_started(p, started_event):
        """Wait for ``started_event``; raise if ``p`` exits without setting it."""
        # A child that crashes before signalling would otherwise block forever.
        while not started_event.wait(1):
            if not p.is_alive() and not started_event.is_set():
                raise RuntimeError(
                    f"{p.name} exited before it finished starting "
                    f"(exitcode={p.exitcode})"
                )

    loom_started = manager.Event()
    # process = Process(
    #     target=run_loom,
    #     name=f"run_loom Server",
    #     kwargs=dict(started_event=loom_started),
    #     daemon=True,
    # )
    # processes["run_loom"] = process

    api_started = manager.Event()
    if args.api:
        process = Process(
            target=run_api_server,
            name=f"API Server",
            kwargs=dict(started_event=api_started, run_mode=run_mode),
            daemon=True,
        )
        processes["api"] = process

    webui_started = manager.Event()
    if args.webui:
        process = Process(
            target=run_webui,
            name=f"WEBUI Server",
            kwargs=dict(started_event=webui_started, run_mode=run_mode),
            daemon=True,
        )
        processes["webui"] = process

    if process_count() == 0:
        parser.print_help()
    else:
        try:
            # Make sure everything shuts down cleanly after SIGINT.
            if p := processes.get("run_loom"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                wait_started(p, loom_started)  # wait for Loom to finish starting

            if p := processes.get("api"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                wait_started(p, api_started)  # wait for api.py to finish starting

            if p := processes.get("webui"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                wait_started(p, webui_started)  # wait for webui.py to finish starting

            dump_server_info(after_start=True, args=args)

            # Wait for every started process to exit. Joining only the WebUI
            # meant that with --api/--all-api alone this returned at once and
            # the finally block below killed the API server just started.
            for p in iter_processes():
                if p.pid is not None:
                    p.join()
        except KeyboardInterrupt:
            # Raised by the SIGINT/SIGTERM handlers installed above.
            logger.warning("Caught KeyboardInterrupt! Setting stop event...")
        except Exception as e:
            logger.error(e)
        finally:
            for p in iter_processes():
                # kill() fails on processes that were never started; skip
                # those, along with processes that have already exited.
                if p.is_alive():
                    logger.warning("Sending SIGKILL to %s", p)
                    # Queues and other inter-process communication primitives can break when
                    # process is killed, but we don't care here
                    p.kill()

            for p in iter_processes():
                logger.info("Process status: %s", p)
if __name__ == "__main__":
    create_tables()

    # asyncio.run() creates a fresh event loop, runs the coroutine to
    # completion and closes the loop afterwards. It replaces version-specific
    # branching in which asyncio.get_running_loop() could never succeed at
    # module level (no loop is running here) and the loop was never closed.
    asyncio.run(start_main_server())