embedding convert endpoint

glide-the 2024-04-07 16:27:20 +08:00
parent 5169228b86
commit 051acfbeae
7 changed files with 412 additions and 332 deletions
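In short, this commit moves the OpenAI-compatibility conversion helpers into a new model_providers.bootstrap_web.message_convert package and fills in the previously stubbed create_embeddings handler: it now resolves a TEXT_EMBEDDING model instance for the requested provider, invokes it on the request's input texts, and converts the result into an OpenAI-style EmbeddingsResponse. A minimal client-side sketch of the new endpoint follows; the URL path, provider name, and model name are illustrative assumptions (the exact route is not shown in this diff), while the host/port match the example server script bundled in this commit (127.0.0.1:20000) and the request fields (model, input) mirror EmbeddingsRequest as used by the handler.

import requests

# Assumed route shape: the handler takes a provider name plus an
# OpenAI-style embeddings payload; the actual path is not part of this diff.
url = "http://127.0.0.1:20000/openai/v1/embeddings"

payload = {
    "model": "text-embedding-ada-002",  # illustrative model name
    "input": "hello world",             # a bare string is wrapped into a list server-side
}

resp = requests.post(url, json=payload, timeout=30)
resp.raise_for_status()
body = resp.json()

# Each entry in `data` carries an `embedding` vector and its `index`;
# `usage` reports token counts, mirroring the OpenAI embeddings schema.
print(len(body["data"][0]["embedding"]), body["usage"])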

View File

@@ -3,11 +3,15 @@ import asyncio
import logging
from model_providers import BootstrapWebBuilder
from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
logger = logging.getLogger(__name__)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-providers",
@@ -26,9 +30,7 @@ if __name__ == '__main__':
)
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=args.model_providers
)
.model_providers_cfg_path(model_providers_cfg_path=args.model_providers)
.host(host="127.0.0.1")
.port(port=20000)
.build()
@@ -36,11 +38,9 @@ if __name__ == '__main__':
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
async def pool_join_thread():
await boot.join()
asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")

View File

@@ -1,27 +0,0 @@
import typing
from subprocess import Popen
from typing import Optional
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
)
from model_providers.core.utils.generic import jsonify
if typing.TYPE_CHECKING:
from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
def create_stream_chunk(
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
) -> str:
choice = ChatCompletionStreamResponseChoice(
index=index, delta=delta, finish_reason=finish_reason
)
chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
return jsonify(chunk)

View File

@@ -0,0 +1,13 @@
from model_providers.bootstrap_web.message_convert.core import (
convert_to_message,
openai_chat_completion,
openai_embedding_text,
stream_openai_chat_completion,
)
__all__ = [
"convert_to_message",
"stream_openai_chat_completion",
"openai_chat_completion",
"openai_embedding_text",
]

View File

@@ -0,0 +1,289 @@
import logging
import typing
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
ChatMessage,
Embeddings,
EmbeddingsResponse,
Finish,
Role,
UsageInfo,
)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from model_providers.core.utils.generic import jsonify
if typing.TYPE_CHECKING:
from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
logger = logging.getLogger(__name__)
MessageLike = Union[ChatMessage, PromptMessage]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]
def create_stream_chunk(
request_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional[Finish] = None,
) -> str:
choice = ChatCompletionStreamResponseChoice(
index=index, delta=delta, finish_reason=finish_reason
)
chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
return jsonify(chunk)
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if isinstance(template, str):
content = template
elif isinstance(template, list):
content = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501
content.append(TextPromptMessageContent(data=text))
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(dict, tmpl)["image_url"]
if isinstance(img_template, str):
img_template_obj = ImagePromptMessageContent(data=img_template)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
url = img_template["url"]
else:
url = None
img_template_obj = ImagePromptMessageContent(data=url)
else:
raise ValueError()
content.append(img_template_obj)
else:
raise ValueError()
else:
raise ValueError()
if message_type in ("human", "user"):
_message = UserPromptMessage(content=content)
elif message_type in ("ai", "assistant"):
_message = AssistantPromptMessage(content=content)
elif message_type == "system":
_message = SystemPromptMessage(content=content)
elif message_type in ("function", "tool"):
_message = ToolPromptMessage(content=content)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
)
return _message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, ChatMessage):
_message = _create_template_from_message_type(
message.role.to_origin_role(), message.content
)
elif isinstance(message, PromptMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
raise ValueError(f"Expected message type string, got {message_type_str}")
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
async def _stream_openai_chat_completion(
response: Generator,
) -> AsyncGenerator[str, None]:
request_id, model = None, None
for chunk in response:
if not isinstance(chunk, LLMResultChunk):
yield "[ERROR]"
return
if model is None:
model = chunk.model
if request_id is None:
request_id = "request_id"
yield create_stream_chunk(
request_id,
model,
ChatCompletionMessage(role=Role.ASSISTANT, content=""),
)
new_token = chunk.delta.message.content
if new_token:
delta = ChatCompletionMessage(
role=Role.value_of(chunk.delta.message.role.to_origin_role()),
content=new_token,
tool_calls=chunk.delta.message.tool_calls,
)
yield create_stream_chunk(
request_id=request_id,
model=model,
delta=delta,
index=chunk.delta.index,
finish_reason=chunk.delta.finish_reason,
)
yield create_stream_chunk(
request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
choice = ChatCompletionResponseChoice(
index=0,
message=ChatCompletionMessage(
**_convert_prompt_message_to_dict(message=response.message)
),
finish_reason=Finish.STOP,
)
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ChatCompletionResponse(
id="request_id",
model=response.model,
choices=[choice],
usage=usage,
)
async def _openai_embedding_text(response: TextEmbeddingResult) -> EmbeddingsResponse:
embedding = [
Embeddings(embedding=embedding, index=index)
for index, embedding in enumerate(response.embeddings)
]
return EmbeddingsResponse(
model=response.model,
data=embedding,
usage=UsageInfo(
prompt_tokens=response.usage.tokens,
total_tokens=response.usage.total_tokens,
completion_tokens=response.usage.total_tokens,
),
)
convert_to_message = _convert_to_message
stream_openai_chat_completion = _stream_openai_chat_completion
openai_chat_completion = _openai_chat_completion
openai_embedding_text = _openai_embedding_text

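The new module's public helpers are re-exported from model_providers.bootstrap_web.message_convert (see the __init__.py above). Below is a minimal sketch of convert_to_message, which the chat endpoint later in this commit uses to normalize incoming messages; per its docstring it accepts a PromptMessage, a ChatMessage, a bare string (treated as a user message), or a 2-tuple of (role, content). The message texts are illustrative only.

from model_providers.bootstrap_web.message_convert import convert_to_message

# A bare string is shorthand for a user ("human") message.
user_msg = convert_to_message("What's the weather like today?")

# A 2-tuple selects the prompt-message class by its role string.
system_msg = convert_to_message(("system", "You are a helpful assistant."))
assistant_msg = convert_to_message(("assistant", "Hello! How can I help?"))

prompt_messages = [system_msg, user_msg, assistant_msg]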
View File

@@ -23,256 +23,38 @@ from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
ProviderListResponse,
ProviderModelTypeResponse,
)
from model_providers.bootstrap_web.message_convert import (
convert_to_message,
openai_chat_completion,
openai_embedding_text,
stream_openai_chat_completion,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
EmbeddingsRequest,
EmbeddingsResponse,
Finish,
FunctionAvailable,
ModelCard,
ModelList,
Role,
UsageInfo,
)
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
from model_providers.core.utils.generic import dictify, jsonify
from model_providers.core.utils.generic import dictify
logger = logging.getLogger(__name__)
MessageLike = Union[ChatMessage, PromptMessage]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
) -> PromptMessage:
"""Create a message prompt template from a message type and template string.
Args:
message_type: str the type of the message template (e.g., "human", "ai", etc.)
template: str the template string.
Returns:
a message prompt template of the appropriate type.
"""
if isinstance(template, str):
content = template
elif isinstance(template, list):
content = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(dict, tmpl)["text"] # type: ignore[assignment] # noqa: E501
content.append(TextPromptMessageContent(data=text))
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(dict, tmpl)["image_url"]
if isinstance(img_template, str):
img_template_obj = ImagePromptMessageContent(data=img_template)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
url = img_template["url"]
else:
url = None
img_template_obj = ImagePromptMessageContent(data=url)
else:
raise ValueError()
content.append(img_template_obj)
else:
raise ValueError()
else:
raise ValueError()
if message_type in ("human", "user"):
_message = UserPromptMessage(content=content)
elif message_type in ("ai", "assistant"):
_message = AssistantPromptMessage(content=content)
elif message_type == "system":
_message = SystemPromptMessage(content=content)
elif message_type in ("function", "tool"):
_message = ToolPromptMessage(content=content)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
)
return _message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- 2-tuple of (message class, template)
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, ChatMessage):
_message = _create_template_from_message_type(
message.role.to_origin_role(), message.content
)
elif isinstance(message, PromptMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
raise ValueError(f"Expected message type string, got {message_type_str}")
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
async def _stream_openai_chat_completion(
response: Generator,
) -> AsyncGenerator[str, None]:
request_id, model = None, None
for chunk in response:
if not isinstance(chunk, LLMResultChunk):
yield "[ERROR]"
return
if model is None:
model = chunk.model
if request_id is None:
request_id = "request_id"
yield create_stream_chunk(
request_id,
model,
ChatCompletionMessage(role=Role.ASSISTANT, content=""),
)
new_token = chunk.delta.message.content
if new_token:
delta = ChatCompletionMessage(
role=Role.value_of(chunk.delta.message.role.to_origin_role()),
content=new_token,
tool_calls=chunk.delta.message.tool_calls,
)
yield create_stream_chunk(
request_id=request_id,
model=model,
delta=delta,
index=chunk.delta.index,
finish_reason=chunk.delta.finish_reason,
)
yield create_stream_chunk(
request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
choice = ChatCompletionResponseChoice(
index=0,
message=ChatCompletionMessage(
**_convert_prompt_message_to_dict(message=response.message)
),
finish_reason=Finish.STOP,
)
usage = UsageInfo(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ChatCompletionResponse(
id="request_id",
model=response.model,
choices=[choice],
usage=usage,
)
class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
"""
@@ -363,14 +145,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
started_event.set()
async def workspaces_model_providers(self, request: Request):
provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
model_type=request.get("model_type"))
provider_list = ProvidersWrapper(
provider_manager=self._provider_manager.provider_manager
).get_provider_list(model_type=request.get("model_type"))
return ProviderListResponse(data=provider_list)
async def workspaces_model_types(self, model_type: str, request: Request):
models_by_model_type = ProvidersWrapper(
provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
provider_manager=self._provider_manager.provider_manager
).get_models_by_model_type(model_type=model_type)
return ProviderModelTypeResponse(data=models_by_model_type)
async def list_models(self, provider: str, request: Request):
@@ -403,17 +186,24 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
return ModelList(data=models_list)
async def create_embeddings(
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
response = None
return EmbeddingsResponse(**dictify(response))
model_instance = self._provider_manager.get_model_instance(
provider=provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embeddings_request.model,
)
texts = embeddings_request.input
if isinstance(texts, str):
texts = [texts]
response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
return await openai_embedding_text(response)
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
self, provider: str, request: Request, chat_request: ChatCompletionRequest
):
logger.info(
f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -423,7 +213,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
provider=provider, model_type=ModelType.LLM, model=chat_request.model
)
prompt_messages = [
_convert_to_message(message) for message in chat_request.messages
convert_to_message(message) for message in chat_request.messages
]
tools = []
@@ -458,11 +248,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
if chat_request.stream:
return EventSourceResponse(
_stream_openai_chat_completion(response),
stream_openai_chat_completion(response),
media_type="text/event-stream",
)
else:
return await _openai_chat_completion(response)
return await openai_chat_completion(response)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -473,9 +263,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
cfg: Dict,
logging_conf: Optional[dict] = None,
started_event: mp.Event = None,
):
logging.config.dictConfig(logging_conf) # type: ignore
try:

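As rewritten above, create_embeddings accepts either a single string or a list of strings as input (a bare string is wrapped into a one-element list before invoke_text_embedding is called), and openai_embedding_text returns one indexed Embeddings entry per input text. A batch-call sketch against the same assumed route as the earlier example:

import requests

resp = requests.post(
    "http://127.0.0.1:20000/openai/v1/embeddings",  # assumed route shape, as before
    json={"model": "text-embedding-ada-002", "input": ["first text", "second text"]},
    timeout=30,
)
resp.raise_for_status()
for item in resp.json()["data"]:
    # index matches the position of the corresponding input text
    print(item["index"], len(item["embedding"]))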
View File

@@ -1,5 +1,4 @@
from typing import Optional, List
from typing import List, Optional
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
@@ -9,10 +8,8 @@ from model_providers.bootstrap_web.entities.model_provider_entities import (
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
@@ -22,7 +19,7 @@ class ProvidersWrapper:
self.provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
@@ -38,8 +35,8 @@
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.get_configurations(provider=provider)
provider_configurations = self.provider_manager.get_configurations(
provider=provider
)
provider_responses = []
@@ -47,8 +44,8 @@
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
@@ -78,7 +75,7 @@
return provider_responses
def get_models_by_model_type(
self, model_type: str
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
@@ -94,8 +91,8 @@
self.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.get_configurations(provider=provider)
provider_configurations = self.provider_manager.get_configurations(
provider=provider
)
# Get provider available models

View File

@@ -13,7 +13,12 @@ from google.generativeai.types import (
HarmBlockThreshold,
HarmCategory,
)
from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
from google.generativeai.types.content_types import (
FunctionDeclaration,
FunctionLibrary,
Tool,
to_part,
)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -58,15 +63,15 @@ if you are not sure about the structure.
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -83,15 +88,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -140,15 +152,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
raise CredentialsValidateFailedError(str(ex))
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -163,9 +175,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs.pop(
"max_tokens_to_sample", None
)
config_kwargs.pop("max_tokens_to_sample", None)
# https://github.com/google/generative-ai-python/issues/170
# config_kwargs["max_output_tokens"] = config_kwargs.pop(
# "max_tokens_to_sample", None
@@ -206,11 +216,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
}
tools_one = []
for tool in tools:
one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
description=tool.description,
parameters=tool.parameters
)
])
one_tool = Tool(
function_declarations=[
FunctionDeclaration(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
]
)
tools_one.append(one_tool)
response = google_model.generate_content(
@@ -231,11 +245,11 @@
)
def _handle_generate_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -262,7 +276,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
tool_calls.append(function_call)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
assistant_prompt_message = AssistantPromptMessage(
content=part.text, tool_calls=tool_calls
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -286,11 +302,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -446,7 +462,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
}
def _extract_response_function_call(
self, response_function_call: Union[FunctionCall, FunctionResponse]
self, response_function_call: Union[FunctionCall, FunctionResponse]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
@@ -471,7 +487,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
arguments=str(map_composite_dict),
)
else:
raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
raise ValueError(
f"Unsupported response_function_call type: {type(response_function_call)}"
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name, type="function", function=function