diff --git a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
index b3a08d87..f06df8f7 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from collections.abc import Generator
 from typing import Optional, Union
@@ -5,13 +6,14 @@ from typing import Optional, Union
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
+from google.ai.generativelanguage_v1beta import FunctionCall, FunctionResponse
 from google.generativeai.types import (
     ContentType,
     GenerateContentResponse,
     HarmBlockThreshold,
     HarmCategory,
 )
-from google.generativeai.types.content_types import to_part
+from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
 
 from model_providers.core.model_runtime.entities.llm_entities import (
     LLMResult,
@@ -56,15 +58,15 @@ if you are not sure about the structure.
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
-        self,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
-        stream: bool = True,
-        user: Optional[str] = None,
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            model_parameters: dict,
+            tools: Optional[list[PromptMessageTool]] = None,
+            stop: Optional[list[str]] = None,
+            stream: bool = True,
+            user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -81,15 +83,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(
-            model, credentials, prompt_messages, model_parameters, stop, stream, user
+            model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
         )
 
     def get_num_tokens(
-        self,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        tools: Optional[list[PromptMessageTool]] = None,
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
         """
         Get number of tokens for given prompt messages
@@ -138,14 +140,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def _generate(
-        self,
-        model: str,
-        credentials: dict,
-        prompt_messages: list[PromptMessage],
-        model_parameters: dict,
-        stop: Optional[list[str]] = None,
-        stream: bool = True,
-        user: Optional[str] = None,
+            self,
+            model: str,
+            credentials: dict,
+            prompt_messages: list[PromptMessage],
+            model_parameters: dict,
+            tools: Optional[list[PromptMessageTool]] = None,
+            stop: Optional[list[str]] = None,
+            stream: bool = True,
+            user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -160,9 +163,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop(
+        config_kwargs.pop(
             "max_tokens_to_sample", None
         )
+        # https://github.com/google/generative-ai-python/issues/170
+        # config_kwargs["max_output_tokens"] = config_kwargs.pop(
+        #     "max_tokens_to_sample", None
+        # )
 
         if stop:
             config_kwargs["stop_sequences"] = stop
@@ -197,12 +204,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
         }
+        tools_one = []
+        for tool in tools or []:
+            one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
+                                                                       description=tool.description,
+                                                                       parameters=tool.parameters
+                                                                       )
+                                                   ])
+            tools_one.append(one_tool)
 
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),
             stream=stream,
             safety_settings=safety_settings,
+            tools=FunctionLibrary(tools=tools_one) if tools_one else None,
         )
 
         if stream:
@@ -215,11 +231,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         )
 
     def _handle_generate_response(
-        self,
-        model: str,
-        credentials: dict,
-        response: GenerateContentResponse,
-        prompt_messages: list[PromptMessage],
+            self,
+            model: str,
+            credentials: dict,
+            response: GenerateContentResponse,
+            prompt_messages: list[PromptMessage],
     ) -> LLMResult:
         """
         Handle llm response
@@ -230,8 +246,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
+        part = response.candidates[0].content.parts[0]
+        part_message_function_call = part.function_call
+        tool_calls = []
+        if part_message_function_call:
+            function_call = self._extract_response_function_call(
+                part_message_function_call
+            )
+            tool_calls.append(function_call)
+        part_message_function_response = part.function_response
+        if part_message_function_response:
+            function_call = self._extract_response_function_call(
+                part_message_function_response
+            )
+            tool_calls.append(function_call)
+
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(content=response.text)
+        assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
 
         # calculate num tokens
         prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -255,11 +286,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         return result
 
     def _handle_generate_stream_response(
-        self,
-        model: str,
-        credentials: dict,
-        response: GenerateContentResponse,
-        prompt_messages: list[PromptMessage],
+            self,
+            model: str,
+            credentials: dict,
+            response: GenerateContentResponse,
+            prompt_messages: list[PromptMessage],
     ) -> Generator:
         """
         Handle llm stream response
@@ -413,3 +444,35 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 exceptions.Cancelled,
             ],
         }
+
+    def _extract_response_function_call(
+        self, response_function_call: Union[FunctionCall, FunctionResponse]
+    ) -> AssistantPromptMessage.ToolCall:
+        """
+        Extract function call from response
+
+        :param response_function_call: response function call
+        :return: tool call
+        """
+        tool_call = None
+        if response_function_call:
+            if isinstance(response_function_call, FunctionCall):
+                map_composite_dict = dict(response_function_call.args.items())
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_function_call.name,
+                    arguments=json.dumps(map_composite_dict, default=str),
+                )
+            elif isinstance(response_function_call, FunctionResponse):
+                map_composite_dict = dict(response_function_call.response.items())
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_function_call.name,
+                    arguments=json.dumps(map_composite_dict, default=str),
+                )
+            else:
+                raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=response_function_call.name, type="function", function=function
+            )
+
+        return tool_call
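Below is a minimal usage sketch of the function-calling path this diff adds. It is illustrative only: the `PromptMessageTool` and `UserPromptMessage` constructors, their import path, the `google_api_key` credential key, the bare `GoogleLargeLanguageModel()` construction, and the `gemini-pro` model name are assumptions drawn from the surrounding provider code, not confirmed by this diff.

```python
# Hypothetical sketch -- entity fields, import paths, credential key and model
# name are assumptions; in practice the provider factory builds the model class.
from model_providers.core.model_runtime.entities.message_entities import (
    PromptMessageTool,
    UserPromptMessage,
)
from model_providers.core.model_runtime.model_providers.google.llm.llm import (
    GoogleLargeLanguageModel,
)

google_llm = GoogleLargeLanguageModel()  # assumption: bare construction works here

# A tool description in the provider-neutral format; _generate converts it into
# a Gemini FunctionDeclaration/Tool pair.
weather_tool = PromptMessageTool(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

result = google_llm._invoke(
    model="gemini-pro",
    credentials={"google_api_key": "<api-key>"},
    prompt_messages=[UserPromptMessage(content="What is the weather in Paris?")],
    model_parameters={"temperature": 0.2},
    tools=[weather_tool],
    stream=False,
)

# If Gemini decides to call the tool, _extract_response_function_call turns the
# returned FunctionCall into an AssistantPromptMessage.ToolCall whose arguments
# field is a JSON string.
for tool_call in result.message.tool_calls:
    print(tool_call.function.name, tool_call.function.arguments)
```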