From 056b15b99b1ac969b40f5357a88332f12103d4b6 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 31 Mar 2024 17:55:57 +0800 Subject: [PATCH] make format --- .../model_providers/bootstrap_web/common.py | 21 ++- .../bootstrap_web/openai_bootstrap_web.py | 145 ++++++++++++------ 2 files changed, 111 insertions(+), 55 deletions(-) diff --git a/model-providers/model_providers/bootstrap_web/common.py b/model-providers/model_providers/bootstrap_web/common.py index 0566e2cd..a06a3064 100644 --- a/model-providers/model_providers/bootstrap_web/common.py +++ b/model-providers/model_providers/bootstrap_web/common.py @@ -2,8 +2,11 @@ import typing from subprocess import Popen from typing import Optional -from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \ - ChatCompletionStreamResponse, Finish +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + Finish, +) from model_providers.core.utils.generic import jsonify if typing.TYPE_CHECKING: @@ -11,12 +14,14 @@ if typing.TYPE_CHECKING: def create_stream_chunk( - request_id: str, - model: str, - delta: "ChatCompletionMessage", - index: Optional[int] = 0, - finish_reason: Optional[Finish] = None, + request_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional[Finish] = None, ) -> str: - choice = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + choice = ChatCompletionStreamResponseChoice( + index=index, delta=delta, finish_reason=finish_reason + ) chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice]) return jsonify(chunk) diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 7a348002..31b2cb77 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -5,7 +5,18 @@ import multiprocessing as mp import os import pprint import threading -from typing import Any, Dict, Optional, Union, Tuple, Type, List, cast, Generator, AsyncGenerator +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -15,21 +26,36 @@ from uvicorn import Config, Server from model_providers.bootstrap_web.common import create_stream_chunk from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionStreamResponse, + ChatMessage, EmbeddingsRequest, EmbeddingsResponse, + Finish, FunctionAvailable, ModelCard, - ModelList, ChatMessage, ChatCompletionMessage, Role, Finish, ChatCompletionResponseChoice, UsageInfo, + ModelList, + Role, + UsageInfo, +) +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, ) -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk - from model_providers.core.model_runtime.entities.message_entities import ( - UserPromptMessage, PromptMessage, AssistantPromptMessage, ToolPromptMessage, SystemPromptMessage, - PromptMessageContent, 
PromptMessageContentType, TextPromptMessageContent, ImagePromptMessageContent, + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -81,7 +107,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: def _create_template_from_message_type( - message_type: str, template: Union[str, list] + message_type: str, template: Union[str, list] ) -> PromptMessage: """Create a message prompt template from a message type and template string. @@ -102,9 +128,7 @@ def _create_template_from_message_type( text: str = tmpl else: text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 - content.append( - TextPromptMessageContent(data=text) - ) + content.append(TextPromptMessageContent(data=text)) elif isinstance(tmpl, dict) and "image_url" in tmpl: img_template = cast(dict, tmpl)["image_url"] if isinstance(img_template, str): @@ -142,7 +166,7 @@ def _create_template_from_message_type( def _convert_to_message( - message: MessageLikeRepresentation, + message: MessageLikeRepresentation, ) -> Union[PromptMessage]: """Instantiate a message from a variety of message formats. @@ -161,7 +185,9 @@ def _convert_to_message( an instance of a message or a message template """ if isinstance(message, ChatMessage): - _message = _create_template_from_message_type(message.role.to_origin_role(), message.content) + _message = _create_template_from_message_type( + message.role.to_origin_role(), message.content + ) elif isinstance(message, PromptMessage): _message = message @@ -174,16 +200,16 @@ def _convert_to_message( if isinstance(message_type_str, str): _message = _create_template_from_message_type(message_type_str, template) else: - raise ValueError( - f"Expected message type string, got {message_type_str}" - ) + raise ValueError(f"Expected message type string, got {message_type_str}") else: raise NotImplementedError(f"Unsupported message type: {type(message)}") return _message -async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]: +async def _stream_openai_chat_completion( + response: Generator, +) -> AsyncGenerator[str, None]: request_id, model = None, None for chunk in response: if not isinstance(chunk, LLMResultChunk): @@ -194,27 +220,41 @@ async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[ model = chunk.model if request_id is None: request_id = "request_id" - yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content="")) + yield create_stream_chunk( + request_id, + model, + ChatCompletionMessage(role=Role.ASSISTANT, content=""), + ) new_token = chunk.delta.message.content if new_token: - delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()), - content=new_token, - tool_calls=chunk.delta.message.tool_calls) - yield create_stream_chunk(request_id=request_id, - model=model, delta=delta, - index=chunk.delta.index, - finish_reason=chunk.delta.finish_reason) + delta = ChatCompletionMessage( + role=Role.value_of(chunk.delta.message.role.to_origin_role()), + content=new_token, + tool_calls=chunk.delta.message.tool_calls, + ) + yield create_stream_chunk( + request_id=request_id, + model=model, + delta=delta, + index=chunk.delta.index, + finish_reason=chunk.delta.finish_reason, + ) - 
yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP) + yield create_stream_chunk( + request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP + ) yield "[DONE]" async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse: choice = ChatCompletionResponseChoice( - index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)), - finish_reason=Finish.STOP + index=0, + message=ChatCompletionMessage( + **_convert_prompt_message_to_dict(message=response.message) + ), + finish_reason=Finish.STOP, ) usage = UsageInfo( prompt_tokens=response.usage.prompt_tokens, @@ -335,7 +375,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -345,7 +385,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -354,22 +394,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): model_instance = self._provider_manager.get_model_instance( provider=provider, model_type=ModelType.LLM, model=chat_request.model ) - prompt_messages = [_convert_to_message(message) for message in chat_request.messages] + prompt_messages = [ + _convert_to_message(message) for message in chat_request.messages + ] - tools = [PromptMessageTool(name=f.function.name, - description=f.function.description, - parameters=f.function.parameters - ) - - for f in chat_request.tools] + tools = [ + PromptMessageTool( + name=f.function.name, + description=f.function.description, + parameters=f.function.parameters, + ) + for f in chat_request.tools + ] if chat_request.functions: - tools.extend([PromptMessageTool(name=f.name, - description=f.description, - parameters=f.parameters - ) for f in chat_request.functions]) + tools.extend( + [ + PromptMessageTool( + name=f.name, description=f.description, parameters=f.parameters + ) + for f in chat_request.functions + ] + ) try: - response = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={**chat_request.to_model_parameters_dict()}, @@ -380,21 +427,25 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) if chat_request.stream: - - return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream") + return EventSourceResponse( + _stream_openai_chat_completion(response), + media_type="text/event-stream", + ) else: return await _openai_chat_completion(response) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except InvokeError as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): 
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
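
Note on the streaming contract touched by the hunks above: _stream_openai_chat_completion adapts the blocking generator returned by model_instance.invoke_llm into OpenAI-style server-sent events via create_stream_chunk, emits a final empty delta with Finish.STOP, then yields a literal "[DONE]" sentinel. Below is a minimal client-side sketch of consuming that stream; it assumes httpx (any HTTP client with response streaming would do), and the route path is an illustrative guess, since the router prefix is not shown in this patch.

import json

import httpx  # assumption: any streaming-capable HTTP client works here


def stream_chat_content(base_url: str, payload: dict):
    # POST to the endpoint served by create_chat_completion; the exact
    # prefix (e.g. /{provider}/v1) depends on how the router is mounted.
    # payload should include "stream": True to take the SSE branch.
    with httpx.stream("POST", f"{base_url}/chat/completions", json=payload) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines between SSE events
            data = line[len("data: "):]
            if data == "[DONE]":  # sentinel yielded after the Finish.STOP chunk
                break
            delta = json.loads(data)["choices"][0]["delta"]
            if delta.get("content"):
                yield delta["content"]

Each event's payload is a serialized ChatCompletionStreamResponse, so the chunk shape (choices[0].delta) mirrors the OpenAI API.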