import asyncio
import logging
import logging.config
import multiprocessing as mp
import os
import pprint
import threading
from typing import Dict, Optional

import tiktoken
from fastapi import APIRouter, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from zhipuai import ZhipuAI

from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ModelList,
)
from model_providers.core.model_runtime.entities.message_entities import (
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)
from model_providers.core.utils.generic import dictify, jsonify

logger = logging.getLogger(__name__)


async def create_stream_chat_completion(
    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
):
    """Invoke the model in streaming mode and return the chunk generator."""
    try:
        response = model_type_instance.invoke(
            model=chat_request.model,
            # NOTE: placeholder credentials; real keys are expected to come
            # from the environment or provider configuration.
            credentials={
                "openai_api_key": "sk-",
                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
            },
            # NOTE: hardcoded demo prompt ("What is the weather like in Beijing
            # today?"); a full implementation would convert chat_request.messages.
            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
            model_parameters={**chat_request.to_model_parameters_dict()},
            stop=chat_request.stop,
            stream=chat_request.stream,
            user="abc-123",
        )
        return response
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=500, detail=str(e))


class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
    Bootstrap Server Lifecycle
    """

    def __init__(self, host: str, port: int):
        super().__init__()
        self._host = host
        self._port = port
        self._router = APIRouter()
        self._app = FastAPI()
        self._server_thread = None

    @classmethod
    def from_config(cls, cfg=None):
        host = cfg.get("host", "127.0.0.1")
        port = cfg.get("port", 20000)
        logger.info(
            f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}"
        )
        return cls(host=host, port=port)

    def serve(self, logging_conf: Optional[dict] = None):
        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        self._router.add_api_route(
            "/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._app.include_router(self._router)

        config = Config(
            app=self._app, host=self._host, port=self._port, log_config=logging_conf
        )
        server = Server(config)
        # Run uvicorn in a background thread so serve() returns immediately.
        self._server_thread = threading.Thread(target=server.run)
        self._server_thread.start()

    async def join(self):
        # threading.Thread.join() is blocking and not awaitable; run it in a
        # worker thread so the event loop stays responsive.
        await asyncio.to_thread(self._server_thread.join)

    def set_app_event(self, started_event: mp.Event = None):
        @self._app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()

    async def list_models(self, request: Request):
        # TODO: not yet implemented; should return a ModelList of available models.
        pass
    async def create_embeddings(
        self, request: Request, embeddings_request: EmbeddingsRequest
    ):
        logger.info(
            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
        )
        api_key = os.environ.get("API_KEY")
        if api_key is None:
            # Fall back to the bearer token from the Authorization header.
            authorization = request.headers.get("Authorization", "")
            authorization = authorization.split("Bearer ")[-1]
        else:
            authorization = api_key
        client = ZhipuAI(api_key=authorization)

        # If embeddings_request.input is a list, it is treated as pre-tokenized
        # input (lists of token ids) and decoded back to text before the call.
        if isinstance(embeddings_request.input, list):
            tokens = embeddings_request.input
            try:
                encoding = tiktoken.encoding_for_model(embeddings_request.model)
            except KeyError:
                logger.warning("Warning: model not found. Using cl100k_base encoding.")
                encoding = tiktoken.get_encoding("cl100k_base")
            embedding_input = ""
            for token_ids in tokens:
                embedding_input += encoding.decode(token_ids)
        else:
            embedding_input = embeddings_request.input

        response = client.embeddings.create(
            model=embeddings_request.model,
            input=embedding_input,
        )
        return EmbeddingsResponse(**dictify(response))

    async def create_chat_completion(
        self, request: Request, chat_request: ChatCompletionRequest
    ):
        logger.info(
            f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
        )
        api_key = os.environ.get("API_KEY")
        if api_key is None:
            authorization = request.headers.get("Authorization", "")
            authorization = authorization.split("Bearer ")[-1]
        else:
            authorization = api_key
        model_provider_factory.get_providers(provider_name="openai")
        provider_instance = model_provider_factory.get_provider_instance("openai")
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)

        if chat_request.stream:
            # create_stream_chat_completion() is a coroutine and must be awaited
            # to obtain the chunk generator before streaming it out.
            response = await create_stream_chat_completion(
                model_type_instance, chat_request
            )
            # Serialize each chunk for the SSE body (this assumes jsonify()
            # renders a result object as a JSON string).
            generator = (jsonify(chunk) for chunk in response)
            return EventSourceResponse(generator, media_type="text/event-stream")
        else:
            # NOTE: model, prompt, parameters and stop words below are hardcoded
            # demo values; a full implementation would take them from chat_request.
            response = model_type_instance.invoke(
                model="gpt-4",
                credentials={
                    "openai_api_key": "sk-",
                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
                },
                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
                model_parameters={
                    "temperature": 0.7,
                    "top_p": 1.0,
                    "top_k": 1,
                    "plugin_web_search": True,
                },
                stop=["you"],
                stream=False,
                user="abc-123",
            )
            chat_response = ChatCompletionResponse(**dictify(response))
            return chat_response


def run(
    cfg: Dict,
    logging_conf: Optional[dict] = None,
    started_event: mp.Event = None,
):
    if logging_conf is not None:
        logging.config.dictConfig(logging_conf)  # type: ignore
    try:
        import signal

        # Ignore keyboard interrupts; signal handling is delegated to xoscar.
        signal.signal(signal.SIGINT, lambda *_: None)

        api = RESTFulOpenAIBootstrapBaseWeb.from_config(
            cfg=cfg.get("run_openai_api", {})
        )
        api.set_app_event(started_event=started_event)
        api.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await api.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise
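
# Usage sketch: a minimal in-process launch of the server via run(). The cfg
# keys mirror what from_config() reads above; the logging dict is an
# illustrative dictConfig example, not a required configuration.
if __name__ == "__main__":
    example_logging_conf = {
        "version": 1,
        "disable_existing_loggers": False,
        "handlers": {"console": {"class": "logging.StreamHandler"}},
        "root": {"handlers": ["console"], "level": "INFO"},
    }
    run(
        cfg={"run_openai_api": {"host": "127.0.0.1", "port": 20000}},
        logging_conf=example_logging_conf,
        started_event=mp.Event(),
    )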