mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 06:03:14 +08:00
parent
b3dee0b1d1
commit
5169228b86
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
@ -5,13 +6,14 @@ from typing import Optional, Union
|
||||
import google.api_core.exceptions as exceptions
|
||||
import google.generativeai as genai
|
||||
import google.generativeai.client as client
|
||||
from google.ai.generativelanguage_v1beta import FunctionCall, FunctionResponse
|
||||
from google.generativeai.types import (
|
||||
ContentType,
|
||||
GenerateContentResponse,
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
)
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -56,15 +58,15 @@ if you are not sure about the structure.
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -81,15 +83,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(
|
||||
model, credentials, prompt_messages, model_parameters, stop, stream, user
|
||||
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -138,14 +140,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -160,9 +163,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
config_kwargs = model_parameters.copy()
|
||||
config_kwargs["max_output_tokens"] = config_kwargs.pop(
|
||||
config_kwargs.pop(
|
||||
"max_tokens_to_sample", None
|
||||
)
|
||||
# https://github.com/google/generative-ai-python/issues/170
|
||||
# config_kwargs["max_output_tokens"] = config_kwargs.pop(
|
||||
# "max_tokens_to_sample", None
|
||||
# )
|
||||
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
@ -197,12 +204,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
tools_one = []
|
||||
for tool in tools:
|
||||
one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters
|
||||
)
|
||||
])
|
||||
tools_one.append(one_tool)
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||
stream=stream,
|
||||
safety_settings=safety_settings,
|
||||
tools=FunctionLibrary(tools=tools_one),
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -215,11 +231,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
def _handle_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@ -230,8 +246,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
part = response.candidates[0].content.parts[0]
|
||||
part_message_function_call = part.function_call
|
||||
tool_calls = []
|
||||
if part_message_function_call:
|
||||
function_call = self._extract_response_function_call(
|
||||
part_message_function_call
|
||||
)
|
||||
tool_calls.append(function_call)
|
||||
part_message_function_response = part.function_response
|
||||
if part_message_function_response:
|
||||
function_call = self._extract_response_function_call(
|
||||
part_message_function_call
|
||||
)
|
||||
tool_calls.append(function_call)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||
@ -255,11 +286,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -413,3 +444,37 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
exceptions.Cancelled,
|
||||
],
|
||||
}
|
||||
|
||||
def _extract_response_function_call(
|
||||
self, response_function_call: Union[FunctionCall, FunctionResponse]
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
|
||||
:param response_function_call: response function call
|
||||
:return: tool call
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
from google.protobuf import json_format
|
||||
|
||||
if isinstance(response_function_call, FunctionCall):
|
||||
map_composite_dict = dict(response_function_call.args.items())
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.name,
|
||||
arguments=str(map_composite_dict),
|
||||
)
|
||||
elif isinstance(response_function_call, FunctionResponse):
|
||||
map_composite_dict = dict(response_function_call.response.items())
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_function_call.name,
|
||||
arguments=str(map_composite_dict),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_function_call.name, type="function", function=function
|
||||
)
|
||||
|
||||
return tool_call
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user