2024-03-27 10:12:28 +08:00

237 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
from typing import Optional, Any, Dict
from fastapi import (APIRouter,
FastAPI,
HTTPException,
Response,
Request,
status
)
import logging
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
import json
import pprint
import tiktoken
from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
from uvicorn import Config, Server
from fastapi.middleware.cors import CORSMiddleware
import multiprocessing as mp
import threading
from sse_starlette import EventSourceResponse
from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.utils.generic import dictify, jsonify
from model_providers.core.model_runtime.model_providers import model_provider_factory
# Module logger shared by the request handlers and ``run``.
logger: logging.Logger = logging.getLogger(__name__)
async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
    """Invoke ``model_type_instance`` using the model, parameters, stop words and
    stream flag taken from ``chat_request``.

    The request's ``stream`` flag is forwarded unchanged, so the return value is
    whatever ``model_type_instance.invoke`` produces for that mode.

    Args:
        model_type_instance: the LLM model instance to call.
        chat_request: the incoming OpenAI-style chat completion request.

    Returns:
        The result of ``model_type_instance.invoke``.

    Raises:
        HTTPException: status 500, wrapping any error raised while building the
            call or by ``invoke`` itself (the original error is chained as the cause).
    """
    # TODO: the credentials, prompt and user id below are hard-coded placeholders;
    # the conversation carried by ``chat_request`` is not sent to the model yet.
    try:
        return model_type_instance.invoke(
            model=chat_request.model,
            credentials={
                'openai_api_key': "sk-",
                'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            },
            prompt_messages=[
                UserPromptMessage(
                    content='北京今天的天气怎么样'
                )
            ],
            # Shallow copy so invoke() cannot mutate a dict owned by the request.
            model_parameters=dict(chat_request.to_model_parameters_dict()),
            stop=chat_request.stop,
            stream=chat_request.stream,
            user="abc-123",
        )
    except Exception as e:
        logger.exception("Chat completion invocation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
    Bootstrap Server Lifecycle

    Exposes an OpenAI-compatible REST API (``/v1/models``, ``/v1/embeddings``,
    ``/v1/chat/completions``) through FastAPI, served by uvicorn on a background
    thread started by :meth:`serve`.
    """

    def __init__(self, host: str, port: int):
        super().__init__()
        self._host = host
        self._port = port
        self._router = APIRouter()
        self._app = FastAPI()
        # Assigned by serve(); stays None until the server has been started.
        self._server_thread: Optional[threading.Thread] = None

    @classmethod
    def from_config(cls, cfg=None):
        """Build an instance from a mapping with optional ``host`` / ``port`` keys.

        Missing keys default to ``"127.0.0.1"`` and ``20000``; ``cfg=None`` is
        treated as an empty mapping.
        """
        if cfg is None:
            cfg = {}
        host = cfg.get("host", "127.0.0.1")
        port = cfg.get("port", 20000)
        logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
        return cls(host=host, port=port)

    def serve(self, logging_conf: Optional[dict] = None):
        """Register middleware and routes, then start uvicorn on a background thread.

        Returns as soon as the thread has been started; await :meth:`join` to
        wait for the server to stop.

        Args:
            logging_conf: forwarded to uvicorn's ``Config`` as ``log_config``.
        """
        # Permissive CORS so browser clients on any origin can call the API.
        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        self._router.add_api_route(
            "/v1/models",
            self.list_models,
            response_model=ModelList,
            methods=["GET"],
        )
        self._router.add_api_route(
            "/v1/embeddings",
            self.create_embeddings,
            response_model=EmbeddingsResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._router.add_api_route(
            "/v1/chat/completions",
            self.create_chat_completion,
            response_model=ChatCompletionResponse,
            status_code=status.HTTP_200_OK,
            methods=["POST"],
        )
        self._app.include_router(self._router)
        config = Config(
            app=self._app, host=self._host, port=self._port, log_config=logging_conf
        )
        server = Server(config)
        self._server_thread = threading.Thread(target=server.run)
        self._server_thread.start()

    async def join(self):
        """Wait for the server thread to finish without blocking the event loop.

        Returns immediately if :meth:`serve` has not been called.
        """
        if self._server_thread is None:
            return
        # Thread.join() blocks and returns None, so it cannot be awaited
        # directly; run it in the loop's default executor instead.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._server_thread.join)

    def set_app_event(self, started_event: mp.Event = None):
        """Register a FastAPI startup hook that sets ``started_event`` (if given)."""
        @self._app.on_event("startup")
        async def on_startup():
            if started_event is not None:
                started_event.set()

    @staticmethod
    def _resolve_api_key(request: Request) -> Optional[str]:
        """Return the non-empty ``API_KEY`` env var, else the request's Bearer token.

        Returns ``None`` when neither the env var nor an ``Authorization`` header
        is present.
        """
        api_key = os.environ.get("API_KEY")
        if api_key:
            return api_key
        authorization = request.headers.get("Authorization")
        if authorization is None:
            return None
        return authorization.split("Bearer ")[-1]

    @staticmethod
    def _decode_token_inputs(items: list, model: str) -> str:
        """Decode a list-valued embeddings input into one concatenated string.

        Accepts a flat list of token ids, a list of token-id lists, or a list of
        strings (passed through). Uses the tiktoken encoding for ``model``,
        falling back to ``cl100k_base`` when tiktoken does not know the model.
        """
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")
        # A flat list of ints is a single tokenized input.
        if items and all(isinstance(item, int) for item in items):
            return encoding.decode(items)
        return "".join(
            item if isinstance(item, str) else encoding.decode(item)
            for item in items
        )

    @staticmethod
    def _iter_stream_events(chunks):
        """Yield each streamed result chunk serialized as a JSON string for SSE."""
        # NOTE(review): chunks are serialized as-is rather than converted to
        # OpenAI's ChatCompletionStreamResponse schema -- confirm clients accept
        # this. Assumes each chunk is a pydantic model accepted by ``jsonify``.
        for chunk in chunks:
            yield jsonify(chunk)

    async def list_models(self, request: Request):
        """``GET /v1/models`` -- not implemented yet.

        Raises:
            HTTPException: 501, rather than returning ``None``, which would not
                satisfy the ``ModelList`` response model.
        """
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail="Listing models is not implemented yet",
        )

    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
        """``POST /v1/embeddings`` -- embed ``embeddings_request.input``.

        A list input is decoded to text with tiktoken and concatenated into a
        single string before being sent to the embeddings client.
        """
        logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
        api_key = self._resolve_api_key(request)
        # NOTE(review): ``ZhipuAI`` is not imported anywhere in this module, so
        # this line raises NameError at runtime. It presumably comes from the
        # third-party ``zhipuai`` package -- confirm and add the import.
        client = ZhipuAI(api_key=api_key)
        if isinstance(embeddings_request.input, list):
            embedding_input = self._decode_token_inputs(
                embeddings_request.input, embeddings_request.model
            )
        else:
            embedding_input = embeddings_request.input
        response = client.embeddings.create(
            model=embeddings_request.model,
            input=embedding_input,
        )
        return EmbeddingsResponse(**dictify(response))

    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
        """``POST /v1/chat/completions`` -- run ``chat_request`` on the ``openai`` provider.

        Returns an SSE stream when ``chat_request.stream`` is true, otherwise a
        ``ChatCompletionResponse``. Invocation errors surface as HTTP 500.
        """
        logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
        # TODO: the caller's API key (API_KEY env var / Bearer header) is not yet
        # forwarded to the provider credentials.
        model_provider_factory.get_providers(provider_name='openai')
        provider_instance = model_provider_factory.get_provider_instance('openai')
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        # One invocation path for both modes, so the request's model, parameters
        # and stop words are honored consistently. Must be awaited: passing the
        # bare coroutine to EventSourceResponse cannot be iterated.
        response = await create_stream_chat_completion(model_type_instance, chat_request)
        if chat_request.stream:
            return EventSourceResponse(
                self._iter_stream_events(response), media_type="text/event-stream"
            )
        # NOTE(review): nothing here maps the model result onto the OpenAI
        # response schema -- confirm ``dictify(response)`` yields the fields
        # ChatCompletionResponse requires.
        return ChatCompletionResponse(**dictify(response))
def run(
    cfg: Dict, logging_conf: Optional[dict] = None,
    started_event: mp.Event = None,
):
    """Configure logging, start the REST server and block until its thread exits.

    Args:
        cfg: settings mapping; its ``"run_openai_api"`` entry (default ``{}``)
            is passed to ``RESTFulOpenAIBootstrapBaseWeb.from_config``.
        logging_conf: ``logging.config.dictConfig`` schema, also forwarded to
            uvicorn. When ``None``, the logging configuration is left untouched.
        started_event: set by the app's startup hook, if given.

    Raises:
        SystemExit: logged and re-raised.
    """
    # ``import logging`` alone does not load the ``logging.config`` submodule.
    import logging.config
    import signal

    if logging_conf is not None:
        logging.config.dictConfig(logging_conf)
    try:
        # Ignore keyboard interrupts here; signal handling is left to xoscar.
        signal.signal(signal.SIGINT, lambda *_: None)
        api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
        api.set_app_event(started_event=started_event)
        api.serve(logging_conf=logging_conf)

        async def pool_join_thread():
            await api.join()

        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")
        raise