diff --git a/model-providers/model_providers/bootstrap_web/common.py b/model-providers/model_providers/bootstrap_web/common.py
new file mode 100644
index 00000000..0566e2cd
--- /dev/null
+++ b/model-providers/model_providers/bootstrap_web/common.py
@@ -0,0 +1,28 @@
+import typing
+from typing import Optional
+
+from model_providers.core.bootstrap.openai_protocol import ChatCompletionStreamResponseChoice, \
+    ChatCompletionStreamResponse, Finish
+from model_providers.core.utils.generic import jsonify
+
+if typing.TYPE_CHECKING:
+    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
+
+
+def create_stream_chunk(
+    request_id: str,
+    model: str,
+    delta: "ChatCompletionMessage",
+    index: Optional[int] = 0,
+    finish_reason: Optional[Finish] = None,
+) -> str:
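+    """Serialize one chunk of an OpenAI-compatible streamed chat completion.
+
+    Example (values are illustrative only):
+        create_stream_chunk("chatcmpl-123", "gpt-3.5-turbo",
+                            ChatCompletionMessage(role=Role.ASSISTANT, content="Hi"))
+        # -> '{"id": "chatcmpl-123", "model": "gpt-3.5-turbo", "choices": [...]}'
+    """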
"user", "content": message.content} + else: + raise ValueError("User message content must be str") + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + # check if last message is user message + message = cast(ToolPromptMessage, message) + message_dict = {"role": "function", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + +def _create_template_from_message_type( + message_type: str, template: Union[str, list] +) -> PromptMessage: + """Create a message prompt template from a message type and template string. + + Args: + message_type: str the type of the message template (e.g., "human", "ai", etc.) + template: str the template string. + + Returns: + a message prompt template of the appropriate type. + """ + if isinstance(template, str): + content = template + elif isinstance(template, list): + content = [] + for tmpl in template: + if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if isinstance(tmpl, str): + text: str = tmpl + else: + text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 + content.append( + TextPromptMessageContent(data=text) + ) + elif isinstance(tmpl, dict) and "image_url" in tmpl: + img_template = cast(dict, tmpl)["image_url"] + if isinstance(img_template, str): + img_template_obj = ImagePromptMessageContent(data=img_template) + elif isinstance(img_template, dict): + img_template = dict(img_template) + if "url" in img_template: + url = img_template["url"] + else: + url = None + img_template_obj = ImagePromptMessageContent(data=url) + else: + raise ValueError() + content.append(img_template_obj) + else: + raise ValueError() + else: + raise ValueError() + + if message_type in ("human", "user"): + _message = UserPromptMessage(content=content) + elif message_type in ("ai", "assistant"): + _message = AssistantPromptMessage(content=content) + elif message_type == "system": + _message = SystemPromptMessage(content=content) + elif message_type in ("function", "tool"): + _message = ToolPromptMessage(content=content) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human'," + f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'." + ) + + return _message + + +def _convert_to_message( + message: MessageLikeRepresentation, +) -> Union[PromptMessage]: + """Instantiate a message from a variety of message formats. 
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> PromptMessage:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - ChatMessage
+    - PromptMessage (returned unchanged)
+    - 2-tuple of (role string, content); e.g., ("human", "hello")
+    - string: shorthand for ("human", content); e.g., "hello"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        a PromptMessage instance
+    """
+    if isinstance(message, ChatMessage):
+        _message = _create_template_from_message_type(message.role.to_origin_role(), message.content)
+    elif isinstance(message, PromptMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_template_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, content), got {message}")
+        message_type_str, template = message
+        if isinstance(message_type_str, str):
+            _message = _create_template_from_message_type(message_type_str, template)
+        else:
+            raise ValueError(
+                f"Expected message type string, got {message_type_str}"
+            )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+async def _stream_openai_chat_completion(response: Generator) -> AsyncGenerator[str, None]:
+    request_id, model = None, None
+    for chunk in response:
+        if not isinstance(chunk, LLMResultChunk):
+            yield "[ERROR]"
+            return
+
+        if model is None:
+            model = chunk.model
+        if request_id is None:
+            request_id = "request_id"
+            yield create_stream_chunk(request_id, model, ChatCompletionMessage(role=Role.ASSISTANT, content=""))
+
+        new_token = chunk.delta.message.content
+
+        if new_token:
+            delta = ChatCompletionMessage(role=Role.value_of(chunk.delta.message.role.to_origin_role()),
+                                          content=new_token,
+                                          tool_calls=chunk.delta.message.tool_calls)
+            yield create_stream_chunk(request_id=request_id,
+                                      model=model, delta=delta,
+                                      index=chunk.delta.index,
+                                      finish_reason=chunk.delta.finish_reason)
+
+    yield create_stream_chunk(request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP)
+    yield "[DONE]"
+
+
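+# SSE wire format produced by _stream_openai_chat_completion, roughly
+# (fields abbreviated, ids and tokens illustrative):
+#   data: {"id": "...", "choices": [{"delta": {"role": "assistant", "content": ""}, ...}]}
+#   data: {"id": "...", "choices": [{"delta": {"content": "Hello"}, ...}]}
+#   data: {"id": "...", "choices": [{"delta": {}, "finish_reason": "stop", ...}]}
+#   data: [DONE]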
+async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
+    choice = ChatCompletionResponseChoice(
+        index=0, message=ChatCompletionMessage(**_convert_prompt_message_to_dict(message=response.message)),
+        finish_reason=Finish.STOP
+    )
+    usage = UsageInfo(
+        prompt_tokens=response.usage.prompt_tokens,
+        completion_tokens=response.usage.completion_tokens,
+        total_tokens=response.usage.total_tokens,
+    )
+    return ChatCompletionResponse(
+        id="request_id",
+        model=response.model,
+        choices=[choice],
+        usage=usage,
+    )
+
 
 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     """
@@ -143,7 +335,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return ModelList(data=models_list)
 
     async def create_embeddings(
-        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
+            self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
@@ -153,7 +345,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return EmbeddingsResponse(**dictify(response))
 
     async def create_chat_completion(
-        self, provider: str, request: Request, chat_request: ChatCompletionRequest
+            self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
         )
@@ -162,38 +354,47 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         model_instance = self._provider_manager.get_model_instance(
             provider=provider, model_type=ModelType.LLM, model=chat_request.model
         )
-        if chat_request.stream:
-            # Invoke model
+        prompt_messages = [_convert_to_message(message) for message in chat_request.messages]
+
+        tools = [PromptMessageTool(name=f.function.name,
+                                   description=f.function.description,
+                                   parameters=f.function.parameters)
+                 for f in (chat_request.tools or [])]
+        if chat_request.functions:
+            tools.extend([PromptMessageTool(name=f.name,
+                                            description=f.description,
+                                            parameters=f.parameters)
+                          for f in chat_request.functions])
+
+        try:
             response = model_instance.invoke_llm(
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+                prompt_messages=prompt_messages,
                 model_parameters={**chat_request.to_model_parameters_dict()},
+                tools=tools,
                 stop=chat_request.stop,
                 stream=chat_request.stream,
                 user="abc-123",
             )
-            return EventSourceResponse(response, media_type="text/event-stream")
-        else:
-            # Invoke model
-            response = model_instance.invoke_llm(
-                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
-                model_parameters={**chat_request.to_model_parameters_dict()},
-                stop=chat_request.stop,
-                stream=chat_request.stream,
-                user="abc-123",
-            )
+            if chat_request.stream:
+                return EventSourceResponse(_stream_openai_chat_completion(response), media_type="text/event-stream")
+            else:
+                return await _openai_chat_completion(response)
+        except ValueError as e:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
 
-        chat_response = ChatCompletionResponse(**dictify(response))
-
-        return chat_response
+        except InvokeError as e:
+            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
 
 
 def run(
-    cfg: Dict,
-    logging_conf: Optional[dict] = None,
-    started_event: mp.Event = None,
+        cfg: Dict,
+        logging_conf: Optional[dict] = None,
+        started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py
index 2753ad5d..ec5ddc3f 100644
--- a/model-providers/model_providers/core/bootstrap/openai_protocol.py
+++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py
@@ -13,12 +13,62 @@ class Role(str, Enum):
     FUNCTION = "function"
     TOOL = "tool"
 
+    @classmethod
+    def value_of(cls, origin_role: str) -> "Role":
+        if origin_role == "user":
+            return cls.USER
+        elif origin_role == "assistant":
+            return cls.ASSISTANT
+        elif origin_role == "system":
+            return cls.SYSTEM
+        elif origin_role == "function":
+            return cls.FUNCTION
+        elif origin_role == "tool":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin role {origin_role}")
+
+    def to_origin_role(self) -> str:
+        if self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.SYSTEM:
+            return "system"
+        elif self == self.FUNCTION:
+            return "function"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
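+
+    # Round-trip example (illustrative): Role.value_of("tool") is Role.TOOL,
+    # and Role.TOOL.to_origin_role() == "tool".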
+
 
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
     TOOL = "tool_calls"
 
+    @classmethod
+    def value_of(cls, origin_finish: str) -> "Finish":
+        if origin_finish == "stop":
+            return cls.STOP
+        elif origin_finish == "length":
+            return cls.LENGTH
+        elif origin_finish == "tool_calls":
+            return cls.TOOL
+        else:
+            raise ValueError(f"invalid origin finish {origin_finish}")
+
+    def to_origin_finish(self) -> str:
+        if self == self.STOP:
+            return "stop"
+        elif self == self.LENGTH:
+            return "length"
+        elif self == self.TOOL:
+            return "tool_calls"
+        else:
+            raise ValueError(f"invalid finish {self}")
+
 
 class ModelCard(BaseModel):
     id: str
@@ -95,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[list[str]] = (None,)
+    stop: Optional[list[str]] = None
     stream: Optional[bool] = False
 
     def to_model_parameters_dict(self, *args, **kwargs):
diff --git a/model-providers/model_providers/core/model_runtime/entities/message_entities.py b/model-providers/model_providers/core/model_runtime/entities/message_entities.py
index c9a823c0..a66294ad 100644
--- a/model-providers/model_providers/core/model_runtime/entities/message_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/message_entities.py
@@ -28,6 +28,23 @@ class PromptMessageRole(Enum):
                 return mode
         raise ValueError(f"invalid prompt message type value {value}")
 
+    def to_origin_role(self) -> str:
+        """
+        Get origin role from prompt message role.
+
+        :return: origin role
+        """
+        if self == self.SYSTEM:
+            return "system"
+        elif self == self.USER:
+            return "user"
+        elif self == self.ASSISTANT:
+            return "assistant"
+        elif self == self.TOOL:
+            return "tool"
+        else:
+            raise ValueError(f"invalid role {self}")
+
 
 class PromptMessageTool(BaseModel):
     """