diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py index 99b6fd1a..ecde3e4a 100644 --- a/model-providers/model_providers/__main__.py +++ b/model-providers/model_providers/__main__.py @@ -3,11 +3,15 @@ import asyncio import logging from model_providers import BootstrapWebBuilder -from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms +from model_providers.core.utils.utils import ( + get_config_dict, + get_log_file, + get_timestamp_ms, +) logger = logging.getLogger(__name__) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--model-providers", @@ -26,9 +30,7 @@ if __name__ == '__main__': ) boot = ( BootstrapWebBuilder() - .model_providers_cfg_path( - model_providers_cfg_path=args.model_providers - ) + .model_providers_cfg_path(model_providers_cfg_path=args.model_providers) .host(host="127.0.0.1") .port(port=20000) .build() @@ -36,11 +38,9 @@ if __name__ == '__main__': boot.set_app_event(started_event=None) boot.serve(logging_conf=logging_conf) - async def pool_join_thread(): await boot.join() - asyncio.run(pool_join_thread()) except SystemExit: logger.info("SystemExit raised, exiting") diff --git a/model-providers/model_providers/bootstrap_web/common.py b/model-providers/model_providers/bootstrap_web/common.py deleted file mode 100644 index a06a3064..00000000 --- a/model-providers/model_providers/bootstrap_web/common.py +++ /dev/null @@ -1,27 +0,0 @@ -import typing -from subprocess import Popen -from typing import Optional - -from model_providers.core.bootstrap.openai_protocol import ( - ChatCompletionStreamResponse, - ChatCompletionStreamResponseChoice, - Finish, -) -from model_providers.core.utils.generic import jsonify - -if typing.TYPE_CHECKING: - from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage - - -def create_stream_chunk( - request_id: str, - model: str, - delta: "ChatCompletionMessage", - index: Optional[int] = 0, - finish_reason: Optional[Finish] = None, -) -> str: - choice = ChatCompletionStreamResponseChoice( - index=index, delta=delta, finish_reason=finish_reason - ) - chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice]) - return jsonify(chunk) diff --git a/model-providers/model_providers/bootstrap_web/message_convert/__init__.py b/model-providers/model_providers/bootstrap_web/message_convert/__init__.py new file mode 100644 index 00000000..61fc5eda --- /dev/null +++ b/model-providers/model_providers/bootstrap_web/message_convert/__init__.py @@ -0,0 +1,13 @@ +from model_providers.bootstrap_web.message_convert.core import ( + convert_to_message, + openai_chat_completion, + openai_embedding_text, + stream_openai_chat_completion, +) + +__all__ = [ + "convert_to_message", + "stream_openai_chat_completion", + "openai_chat_completion", + "openai_embedding_text", +] diff --git a/model-providers/model_providers/bootstrap_web/message_convert/core.py b/model-providers/model_providers/bootstrap_web/message_convert/core.py new file mode 100644 index 00000000..13fc2bbc --- /dev/null +++ b/model-providers/model_providers/bootstrap_web/message_convert/core.py @@ -0,0 +1,289 @@ +import logging +import typing +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, + 
ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + ChatMessage, + Embeddings, + EmbeddingsResponse, + Finish, + Role, + UsageInfo, +) +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, +) +from model_providers.core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) +from model_providers.core.utils.generic import jsonify + +if typing.TYPE_CHECKING: + from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage + +logger = logging.getLogger(__name__) + +MessageLike = Union[ChatMessage, PromptMessage] + +MessageLikeRepresentation = Union[ + MessageLike, + Tuple[Union[str, Type], Union[str, List[dict], List[object]]], + str, +] + + +def create_stream_chunk( + request_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional[Finish] = None, +) -> str: + choice = ChatCompletionStreamResponseChoice( + index=index, delta=delta, finish_reason=finish_reason + ) + chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice]) + return jsonify(chunk) + + +def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError("User message content must be str") + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + # check if last message is user message + message = cast(ToolPromptMessage, message) + message_dict = {"role": "function", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + +def _create_template_from_message_type( + message_type: str, template: Union[str, list] +) -> PromptMessage: + """Create a message prompt template from a message type and template string. + + Args: + message_type: str the type of the message template (e.g., "human", "ai", etc.) + template: str the template string. + + Returns: + a message prompt template of the appropriate type. 
+ """ + if isinstance(template, str): + content = template + elif isinstance(template, list): + content = [] + for tmpl in template: + if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if isinstance(tmpl, str): + text: str = tmpl + else: + text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 + content.append(TextPromptMessageContent(data=text)) + elif isinstance(tmpl, dict) and "image_url" in tmpl: + img_template = cast(dict, tmpl)["image_url"] + if isinstance(img_template, str): + img_template_obj = ImagePromptMessageContent(data=img_template) + elif isinstance(img_template, dict): + img_template = dict(img_template) + if "url" in img_template: + url = img_template["url"] + else: + url = None + img_template_obj = ImagePromptMessageContent(data=url) + else: + raise ValueError() + content.append(img_template_obj) + else: + raise ValueError() + else: + raise ValueError() + + if message_type in ("human", "user"): + _message = UserPromptMessage(content=content) + elif message_type in ("ai", "assistant"): + _message = AssistantPromptMessage(content=content) + elif message_type == "system": + _message = SystemPromptMessage(content=content) + elif message_type in ("function", "tool"): + _message = ToolPromptMessage(content=content) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human'," + f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'." + ) + + return _message + + +def _convert_to_message( + message: MessageLikeRepresentation, +) -> Union[PromptMessage]: + """Instantiate a message from a variety of message formats. + + The message format can be one of the following: + + - BaseMessagePromptTemplate + - BaseMessage + - 2-tuple of (role string, template); e.g., ("human", "{user_input}") + - 2-tuple of (message class, template) + - string: shorthand for ("human", template); e.g., "{user_input}" + + Args: + message: a representation of a message in one of the supported formats + + Returns: + an instance of a message or a message template + """ + if isinstance(message, ChatMessage): + _message = _create_template_from_message_type( + message.role.to_origin_role(), message.content + ) + + elif isinstance(message, PromptMessage): + _message = message + elif isinstance(message, str): + _message = _create_template_from_message_type("human", message) + elif isinstance(message, tuple): + if len(message) != 2: + raise ValueError(f"Expected 2-tuple of (role, template), got {message}") + message_type_str, template = message + if isinstance(message_type_str, str): + _message = _create_template_from_message_type(message_type_str, template) + else: + raise ValueError(f"Expected message type string, got {message_type_str}") + else: + raise NotImplementedError(f"Unsupported message type: {type(message)}") + + return _message + + +async def _stream_openai_chat_completion( + response: Generator, +) -> AsyncGenerator[str, None]: + request_id, model = None, None + for chunk in response: + if not isinstance(chunk, LLMResultChunk): + yield "[ERROR]" + return + + if model is None: + model = chunk.model + if request_id is None: + request_id = "request_id" + yield create_stream_chunk( + request_id, + model, + ChatCompletionMessage(role=Role.ASSISTANT, content=""), + ) + + new_token = chunk.delta.message.content + + if new_token: + delta = ChatCompletionMessage( + role=Role.value_of(chunk.delta.message.role.to_origin_role()), + content=new_token, + tool_calls=chunk.delta.message.tool_calls, + ) + yield create_stream_chunk( + 
request_id=request_id, + model=model, + delta=delta, + index=chunk.delta.index, + finish_reason=chunk.delta.finish_reason, + ) + + yield create_stream_chunk( + request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP + ) + yield "[DONE]" + + +async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse: + choice = ChatCompletionResponseChoice( + index=0, + message=ChatCompletionMessage( + **_convert_prompt_message_to_dict(message=response.message) + ), + finish_reason=Finish.STOP, + ) + usage = UsageInfo( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + ) + return ChatCompletionResponse( + id="request_id", + model=response.model, + choices=[choice], + usage=usage, + ) + + +async def _openai_embedding_text(response: TextEmbeddingResult) -> EmbeddingsResponse: + embedding = [ + Embeddings(embedding=embedding, index=index) + for index, embedding in enumerate(response.embeddings) + ] + + return EmbeddingsResponse( + model=response.model, + data=embedding, + usage=UsageInfo( + prompt_tokens=response.usage.tokens, + total_tokens=response.usage.total_tokens, + completion_tokens=response.usage.total_tokens, + ), + ) + + +convert_to_message = _convert_to_message +stream_openai_chat_completion = _stream_openai_chat_completion +openai_chat_completion = _openai_chat_completion +openai_embedding_text = _openai_embedding_text diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 39d91570..3c3e91af 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -23,256 +23,38 @@ from fastapi.middleware.cors import CORSMiddleware from sse_starlette import EventSourceResponse from uvicorn import Config, Server -from model_providers.bootstrap_web.common import create_stream_chunk from model_providers.bootstrap_web.entities.model_provider_entities import ( ProviderListResponse, ProviderModelTypeResponse, ) +from model_providers.bootstrap_web.message_convert import ( + convert_to_message, + openai_chat_completion, + openai_embedding_text, + stream_openai_chat_completion, +) from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.openai_protocol import ( - ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionStreamResponse, - ChatMessage, EmbeddingsRequest, EmbeddingsResponse, - Finish, - FunctionAvailable, ModelCard, ModelList, - Role, - UsageInfo, ) from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper -from model_providers.core.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, -) from model_providers.core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, ) from model_providers.core.model_runtime.entities.model_entities import ( AIModelEntity, ModelType, ) from model_providers.core.model_runtime.errors.invoke import InvokeError -from model_providers.core.utils.generic import dictify, jsonify +from model_providers.core.utils.generic import dictify logger = logging.getLogger(__name__) 
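# --- Illustrative sketch (not part of the diff): how the handlers further down use
# the relocated message_convert helpers.  `chat_request` is a ChatCompletionRequest
# and `response` is whatever the provider model instance returned (an LLMResult, or
# a chunk generator when streaming); the invoke call itself is unchanged and is
# therefore not shown here.
from sse_starlette import EventSourceResponse

from model_providers.bootstrap_web.message_convert import (
    convert_to_message,
    openai_chat_completion,
    stream_openai_chat_completion,
)


def to_prompt_messages(chat_request):
    # convert_to_message accepts ChatMessage objects, ("role", "text") tuples,
    # or bare strings (shorthand for a user message).
    return [convert_to_message(message) for message in chat_request.messages]


async def to_openai_response(chat_request, response):
    if chat_request.stream:
        # stream_openai_chat_completion yields SSE-ready JSON chunks and ends
        # with a "[DONE]" sentinel, matching the OpenAI streaming format.
        return EventSourceResponse(
            stream_openai_chat_completion(response), media_type="text/event-stream"
        )
    return await openai_chat_completion(response)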
-MessageLike = Union[ChatMessage, PromptMessage] - -MessageLikeRepresentation = Union[ - MessageLike, - Tuple[Union[str, Type], Union[str, List[dict], List[object]]], - str, -] - - -def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: - """ - Convert PromptMessage to dict for OpenAI Compatibility API - """ - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - raise ValueError("User message content must be str") - elif isinstance(message, AssistantPromptMessage): - message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls and len(message.tool_calls) > 0: - message_dict["function_call"] = { - "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments, - } - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, ToolPromptMessage): - # check if last message is user message - message = cast(ToolPromptMessage, message) - message_dict = {"role": "function", "content": message.content} - else: - raise ValueError(f"Unknown message type {type(message)}") - - return message_dict - - -def _create_template_from_message_type( - message_type: str, template: Union[str, list] -) -> PromptMessage: - """Create a message prompt template from a message type and template string. - - Args: - message_type: str the type of the message template (e.g., "human", "ai", etc.) - template: str the template string. - - Returns: - a message prompt template of the appropriate type. - """ - if isinstance(template, str): - content = template - elif isinstance(template, list): - content = [] - for tmpl in template: - if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: - if isinstance(tmpl, str): - text: str = tmpl - else: - text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501 - content.append(TextPromptMessageContent(data=text)) - elif isinstance(tmpl, dict) and "image_url" in tmpl: - img_template = cast(dict, tmpl)["image_url"] - if isinstance(img_template, str): - img_template_obj = ImagePromptMessageContent(data=img_template) - elif isinstance(img_template, dict): - img_template = dict(img_template) - if "url" in img_template: - url = img_template["url"] - else: - url = None - img_template_obj = ImagePromptMessageContent(data=url) - else: - raise ValueError() - content.append(img_template_obj) - else: - raise ValueError() - else: - raise ValueError() - - if message_type in ("human", "user"): - _message = UserPromptMessage(content=content) - elif message_type in ("ai", "assistant"): - _message = AssistantPromptMessage(content=content) - elif message_type == "system": - _message = SystemPromptMessage(content=content) - elif message_type in ("function", "tool"): - _message = ToolPromptMessage(content=content) - else: - raise ValueError( - f"Unexpected message type: {message_type}. Use one of 'human'," - f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'." - ) - - return _message - - -def _convert_to_message( - message: MessageLikeRepresentation, -) -> Union[PromptMessage]: - """Instantiate a message from a variety of message formats. 
- - The message format can be one of the following: - - - BaseMessagePromptTemplate - - BaseMessage - - 2-tuple of (role string, template); e.g., ("human", "{user_input}") - - 2-tuple of (message class, template) - - string: shorthand for ("human", template); e.g., "{user_input}" - - Args: - message: a representation of a message in one of the supported formats - - Returns: - an instance of a message or a message template - """ - if isinstance(message, ChatMessage): - _message = _create_template_from_message_type( - message.role.to_origin_role(), message.content - ) - - elif isinstance(message, PromptMessage): - _message = message - elif isinstance(message, str): - _message = _create_template_from_message_type("human", message) - elif isinstance(message, tuple): - if len(message) != 2: - raise ValueError(f"Expected 2-tuple of (role, template), got {message}") - message_type_str, template = message - if isinstance(message_type_str, str): - _message = _create_template_from_message_type(message_type_str, template) - else: - raise ValueError(f"Expected message type string, got {message_type_str}") - else: - raise NotImplementedError(f"Unsupported message type: {type(message)}") - - return _message - - -async def _stream_openai_chat_completion( - response: Generator, -) -> AsyncGenerator[str, None]: - request_id, model = None, None - for chunk in response: - if not isinstance(chunk, LLMResultChunk): - yield "[ERROR]" - return - - if model is None: - model = chunk.model - if request_id is None: - request_id = "request_id" - yield create_stream_chunk( - request_id, - model, - ChatCompletionMessage(role=Role.ASSISTANT, content=""), - ) - - new_token = chunk.delta.message.content - - if new_token: - delta = ChatCompletionMessage( - role=Role.value_of(chunk.delta.message.role.to_origin_role()), - content=new_token, - tool_calls=chunk.delta.message.tool_calls, - ) - yield create_stream_chunk( - request_id=request_id, - model=model, - delta=delta, - index=chunk.delta.index, - finish_reason=chunk.delta.finish_reason, - ) - - yield create_stream_chunk( - request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP - ) - yield "[DONE]" - - -async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse: - choice = ChatCompletionResponseChoice( - index=0, - message=ChatCompletionMessage( - **_convert_prompt_message_to_dict(message=response.message) - ), - finish_reason=Finish.STOP, - ) - usage = UsageInfo( - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - ) - return ChatCompletionResponse( - id="request_id", - model=response.model, - choices=[choice], - usage=usage, - ) - class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): """ @@ -363,14 +145,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): started_event.set() async def workspaces_model_providers(self, request: Request): - - provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list( - model_type=request.get("model_type")) + provider_list = ProvidersWrapper( + provider_manager=self._provider_manager.provider_manager + ).get_provider_list(model_type=request.get("model_type")) return ProviderListResponse(data=provider_list) async def workspaces_model_types(self, model_type: str, request: Request): models_by_model_type = ProvidersWrapper( - provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type) + 
provider_manager=self._provider_manager.provider_manager + ).get_models_by_model_type(model_type=model_type) return ProviderModelTypeResponse(data=models_by_model_type) async def list_models(self, provider: str, request: Request): @@ -403,17 +186,24 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" ) - - response = None - return EmbeddingsResponse(**dictify(response)) + model_instance = self._provider_manager.get_model_instance( + provider=provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embeddings_request.model, + ) + texts = embeddings_request.input + if isinstance(texts, str): + texts = [texts] + response = model_instance.invoke_text_embedding(texts=texts, user="abc-123") + return await openai_embedding_text(response) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -423,7 +213,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): provider=provider, model_type=ModelType.LLM, model=chat_request.model ) prompt_messages = [ - _convert_to_message(message) for message in chat_request.messages + convert_to_message(message) for message in chat_request.messages ] tools = [] @@ -458,11 +248,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): if chat_request.stream: return EventSourceResponse( - _stream_openai_chat_completion(response), + stream_openai_chat_completion(response), media_type="text/event-stream", ) else: - return await _openai_chat_completion(response) + return await openai_chat_completion(response) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) @@ -473,9 +263,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: diff --git a/model-providers/model_providers/core/bootstrap/providers_wapper.py b/model-providers/model_providers/core/bootstrap/providers_wapper.py index 8a2af953..d958a999 100644 --- a/model-providers/model_providers/core/bootstrap/providers_wapper.py +++ b/model-providers/model_providers/core/bootstrap/providers_wapper.py @@ -1,5 +1,4 @@ -from typing import Optional, List - +from typing import List, Optional from model_providers.bootstrap_web.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -9,10 +8,8 @@ from model_providers.bootstrap_web.entities.model_provider_entities import ( ProviderWithModelsResponse, SystemConfigurationResponse, ) - from model_providers.core.entities.model_entities import ModelStatus from model_providers.core.entities.provider_entities import ProviderType - from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.provider_manager import ProviderManager @@ -22,7 +19,7 @@ class ProvidersWrapper: self.provider_manager = provider_manager def get_provider_list( - self, model_type: 
Optional[str] = None + self, model_type: Optional[str] = None ) -> List[ProviderResponse]: """ get provider list. @@ -38,8 +35,8 @@ class ProvidersWrapper: self.provider_manager.provider_name_to_provider_model_records_dict.keys() ) # Get all provider configurations of the current workspace - provider_configurations = ( - self.provider_manager.get_configurations(provider=provider) + provider_configurations = self.provider_manager.get_configurations( + provider=provider ) provider_responses = [] @@ -47,8 +44,8 @@ class ProvidersWrapper: if model_type: model_type_entity = ModelType.value_of(model_type) if ( - model_type_entity - not in provider_configuration.provider.supported_model_types + model_type_entity + not in provider_configuration.provider.supported_model_types ): continue @@ -78,7 +75,7 @@ class ProvidersWrapper: return provider_responses def get_models_by_model_type( - self, model_type: str + self, model_type: str ) -> List[ProviderWithModelsResponse]: """ get models by model type. @@ -94,8 +91,8 @@ class ProvidersWrapper: self.provider_manager.provider_name_to_provider_model_records_dict.keys() ) # Get all provider configurations of the current workspace - provider_configurations = ( - self.provider_manager.get_configurations(provider=provider) + provider_configurations = self.provider_manager.get_configurations( + provider=provider ) # Get provider available models diff --git a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py index f06df8f7..a23df8aa 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py @@ -13,7 +13,12 @@ from google.generativeai.types import ( HarmBlockThreshold, HarmCategory, ) -from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary +from google.generativeai.types.content_types import ( + FunctionDeclaration, + FunctionLibrary, + Tool, + to_part, +) from model_providers.core.model_runtime.entities.llm_entities import ( LLMResult, @@ -58,15 +63,15 @@ if you are not sure about the structure. 
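# --- Illustrative sketch: the tool conversion that _generate() below now performs
# with the newly imported FunctionDeclaration/Tool types.  The weather tool here is
# a made-up example payload, not something taken from this diff.
from google.generativeai.types.content_types import FunctionDeclaration, Tool

from model_providers.core.model_runtime.entities.message_entities import (
    PromptMessageTool,
)

weather_tool = PromptMessageTool(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

# One google-generativeai Tool per PromptMessageTool, mirroring how _generate
# builds its `tools_one` list.
gemini_tool = Tool(
    function_declarations=[
        FunctionDeclaration(
            name=weather_tool.name,
            description=weather_tool.description,
            parameters=weather_tool.parameters,
        )
    ]
)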
class GoogleLargeLanguageModel(LargeLanguageModel): def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -83,15 +88,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate( - model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, ) def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, ) -> int: """ Get number of tokens for given prompt messages @@ -140,15 +152,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -163,9 +175,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs.pop( - "max_tokens_to_sample", None - ) + config_kwargs.pop("max_tokens_to_sample", None) # https://github.com/google/generative-ai-python/issues/170 # config_kwargs["max_output_tokens"] = config_kwargs.pop( # "max_tokens_to_sample", None @@ -206,11 +216,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel): } tools_one = [] for tool in tools: - one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name, - description=tool.description, - parameters=tool.parameters - ) - ]) + one_tool = Tool( + function_declarations=[ + FunctionDeclaration( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + ] + ) tools_one.append(one_tool) response = google_model.generate_content( @@ -231,11 +245,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ) def _handle_generate_response( - self, - model: str, - credentials: dict, - response: GenerateContentResponse, - prompt_messages: list[PromptMessage], + self, + model: str, + credentials: dict, + response: GenerateContentResponse, + prompt_messages: list[PromptMessage], ) -> LLMResult: """ Handle llm response @@ -262,7 +276,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): tool_calls.append(function_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls) + 
assistant_prompt_message = AssistantPromptMessage( + content=part.text, tool_calls=tool_calls + ) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -286,11 +302,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): return result def _handle_generate_stream_response( - self, - model: str, - credentials: dict, - response: GenerateContentResponse, - prompt_messages: list[PromptMessage], + self, + model: str, + credentials: dict, + response: GenerateContentResponse, + prompt_messages: list[PromptMessage], ) -> Generator: """ Handle llm stream response @@ -446,7 +462,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): } def _extract_response_function_call( - self, response_function_call: Union[FunctionCall, FunctionResponse] + self, response_function_call: Union[FunctionCall, FunctionResponse] ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -471,7 +487,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel): arguments=str(map_composite_dict), ) else: - raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}") + raise ValueError( + f"Unsupported response_function_call type: {type(response_function_call)}" + ) tool_call = AssistantPromptMessage.ToolCall( id=response_function_call.name, type="function", function=function
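# --- Illustrative sketch: the embeddings path added in create_embeddings earlier in
# this diff.  Provider name, model name, and user id are placeholder values;
# `provider_manager` stands in for self._provider_manager on the bootstrap class.
from model_providers.bootstrap_web.message_convert import openai_embedding_text
from model_providers.core.model_runtime.entities.model_entities import ModelType


async def embed(provider_manager, texts):
    if isinstance(texts, str):
        texts = [texts]  # the endpoint accepts a single string or a list of strings
    model_instance = provider_manager.get_model_instance(
        provider="openai",                    # placeholder provider name
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-ada-002",       # placeholder model name
    )
    result = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
    # openai_embedding_text converts the runtime TextEmbeddingResult into the
    # OpenAI-style EmbeddingsResponse body.
    return await openai_embedding_text(result)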