"""Entry point that builds and serves the Langchain-Chatchat FastAPI application."""

import argparse
import os
from typing import Literal

from fastapi import FastAPI, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
import uvicorn

from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT
from chatchat.configs import OPEN_CROSS_DOMAIN
from chatchat.server.api_server.chat_routes import chat_router
from chatchat.server.api_server.kb_routes import kb_router
from chatchat.server.api_server.openai_routes import openai_router
from chatchat.server.api_server.server_routes import server_router
from chatchat.server.api_server.tool_routes import tool_router
from chatchat.server.chat.completion import completion
from chatchat.server.utils import MakeFastAPIOffline


def create_app(run_mode: str = None):
    """Build the FastAPI application with its routers and static mounts.

    Args:
        run_mode: Accepted for caller compatibility; not read anywhere in
            this function.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
    )
    # NOTE(review): presumably makes the docs pages work without a CDN;
    # confirm against chatchat.server.utils.
    MakeFastAPIOffline(app)

    # Cross-origin requests are allowed only when OPEN_CROSS_DOMAIN
    # (imported from chatchat.configs) is truthy; every origin, method and
    # header is then accepted.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # The root path only redirects to the Swagger docs, so it is hidden
    # from the generated schema.
    @app.get("/", summary="swagger 文档", include_in_schema=False)
    async def document():
        return RedirectResponse(url="/docs")

    app.include_router(chat_router)
    app.include_router(kb_router)
    app.include_router(tool_router)
    app.include_router(openai_router)
    app.include_router(server_router)

    # Other endpoints: register `completion` as the POST handler.
    app.post(
        "/other/completion",
        tags=["Other"],
        summary="要求llm模型补全(通过LLMChain)",
    )(completion)

    # Media files.
    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")

    # Project-related images.
    img_dir = os.path.join(CHATCHAT_ROOT, "img")
    app.mount("/img", StaticFiles(directory=img_dir), name="img")

    return app


def run_api(host, port, **kwargs):
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Interface to bind to.
        port: Port to listen on.
        **kwargs: Only ``ssl_keyfile`` and ``ssl_certfile`` are read. They
            are forwarded to uvicorn only when both are truthy; if either
            is missing, neither is passed. Any other keyword is ignored.
    """
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")

    # Both the key and the certificate are required before either is
    # handed to uvicorn.
    ssl_options = {}
    if ssl_keyfile and ssl_certfile:
        ssl_options = {"ssl_keyfile": ssl_keyfile, "ssl_certfile": ssl_certfile}

    # `app` is the module-level instance created below; it exists by the
    # time this function is called.
    uvicorn.run(app, host=host, port=port, **ssl_options)


app = create_app()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    # Parse the command-line arguments and start the server.
    args = parser.parse_args()

    run_api(
        host=args.host,
        port=args.port,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )