Dev model providers (#3628)

* gemini: fix initialization parameter handling (max_tokens_to_sample is now dropped instead of being mapped to max_output_tokens; see google/generative-ai-python#170)

* gemini: synchronous tool calling (a usage sketch follows the commit metadata below)
glide-the 2024-04-06 23:25:33 +08:00 committed by GitHub
parent b3dee0b1d1
commit 5169228b86

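Before the file diff, a minimal usage sketch of the tool-calling path this commit enables. The PromptMessageTool fields mirror what _generate() now maps onto Gemini FunctionDeclaration objects; the import paths, the credential key, the model name, and the direct instantiation are illustrative assumptions, not taken from this commit.

from model_providers.core.model_runtime.entities.message_entities import (  # assumed module path
    PromptMessageTool,
    UserPromptMessage,
)

# Hypothetical tool definition; parameters is a JSON-Schema-style dict, which is
# what _generate() passes straight into FunctionDeclaration(parameters=...).
weather_tool = PromptMessageTool(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

# GoogleLargeLanguageModel is the class changed in this diff; it is assumed here
# to be importable and constructible directly for illustration.
llm = GoogleLargeLanguageModel()
result = llm._invoke(
    model="gemini-pro",
    credentials={"google_api_key": "..."},  # assumed credential key
    prompt_messages=[UserPromptMessage(content="What is the weather in Paris?")],
    model_parameters={"temperature": 0.2},
    tools=[weather_tool],  # now forwarded to _generate() by this change
    stream=False,
)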

@@ -1,3 +1,4 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
@@ -5,13 +6,14 @@ from typing import Optional, Union
import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
from google.ai.generativelanguage_v1beta import FunctionCall, FunctionResponse
from google.generativeai.types import (
ContentType,
GenerateContentResponse,
HarmBlockThreshold,
HarmCategory,
)
from google.generativeai.types.content_types import to_part
from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -56,15 +58,15 @@ if you are not sure about the structure.
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -81,15 +83,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(
model, credentials, prompt_messages, model_parameters, stop, stream, user
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -138,14 +140,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -160,9 +163,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop(
config_kwargs.pop(
"max_tokens_to_sample", None
)
# https://github.com/google/generative-ai-python/issues/170
# config_kwargs["max_output_tokens"] = config_kwargs.pop(
# "max_tokens_to_sample", None
# )
if stop:
config_kwargs["stop_sequences"] = stop
@@ -197,12 +204,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
tools_one = []
for tool in tools or []:  # tools defaults to None, so guard before iterating
    one_tool = Tool(
        function_declarations=[
            FunctionDeclaration(
                name=tool.name,
                description=tool.description,
                parameters=tool.parameters,
            )
        ]
    )
    tools_one.append(one_tool)
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
safety_settings=safety_settings,
tools=FunctionLibrary(tools=tools_one),
)
if stream:
@@ -215,11 +231,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
)
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -230,8 +246,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
part = response.candidates[0].content.parts[0]
part_message_function_call = part.function_call
tool_calls = []
if part_message_function_call:
function_call = self._extract_response_function_call(
part_message_function_call
)
tool_calls.append(function_call)
part_message_function_response = part.function_response
if part_message_function_response:
function_call = self._extract_response_function_call(
part_message_function_response
)
tool_calls.append(function_call)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=response.text)
assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -255,11 +286,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -413,3 +444,37 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
exceptions.Cancelled,
],
}
def _extract_response_function_call(
self, response_function_call: Union[FunctionCall, FunctionResponse]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
from google.protobuf import json_format
if isinstance(response_function_call, FunctionCall):
map_composite_dict = dict(response_function_call.args.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=str(map_composite_dict),
)
elif isinstance(response_function_call, FunctionResponse):
map_composite_dict = dict(response_function_call.response.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=str(map_composite_dict),
)
else:
raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name, type="function", function=function
)
return tool_call
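For reference, a minimal sketch (not part of the commit) of what the new _extract_response_function_call helper produces for a Gemini function call. The example values are made up; FunctionCall is the proto type already imported at the top of the file, and it is assumed here that a plain dict is accepted for its args field.

from google.ai.generativelanguage_v1beta import FunctionCall

# Hypothetical function-call part, as it would appear in
# response.candidates[0].content.parts[0].function_call
call = FunctionCall(name="get_weather", args={"city": "Paris"})

llm = GoogleLargeLanguageModel()  # assumed to be constructible directly for illustration
tool_call = llm._extract_response_function_call(call)

# tool_call.id reuses the function name ("get_weather"); Gemini provides no separate call id
# tool_call.function.name == "get_weather"
# tool_call.function.arguments is str() of a dict, e.g. "{'city': 'Paris'}", not JSON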