embedding convert endpoint

glide-the 2024-04-07 16:27:20 +08:00
parent 5169228b86
commit 051acfbeae
7 changed files with 412 additions and 332 deletions

View File

@@ -3,11 +3,15 @@ import asyncio
import logging
from model_providers import BootstrapWebBuilder
-from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
logger = logging.getLogger(__name__)
-if __name__ == '__main__':
+if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-providers",
@@ -26,9 +30,7 @@ if __name__ == '__main__':
    )
    boot = (
        BootstrapWebBuilder()
-        .model_providers_cfg_path(
-            model_providers_cfg_path=args.model_providers
-        )
+        .model_providers_cfg_path(model_providers_cfg_path=args.model_providers)
        .host(host="127.0.0.1")
        .port(port=20000)
        .build()
@@ -36,11 +38,9 @@ if __name__ == '__main__':
        boot.set_app_event(started_event=None)
        boot.serve(logging_conf=logging_conf)
        async def pool_join_thread():
            await boot.join()
        asyncio.run(pool_join_thread())
    except SystemExit:
        logger.info("SystemExit raised, exiting")

View File

@@ -1,27 +0,0 @@
import typing
from subprocess import Popen
from typing import Optional

from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
)
from model_providers.core.utils.generic import jsonify

if typing.TYPE_CHECKING:
    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage


def create_stream_chunk(
    request_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional[Finish] = None,
) -> str:
    choice = ChatCompletionStreamResponseChoice(
        index=index, delta=delta, finish_reason=finish_reason
    )
    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
    return jsonify(chunk)

View File

@@ -0,0 +1,13 @@
from model_providers.bootstrap_web.message_convert.core import (
    convert_to_message,
    openai_chat_completion,
    openai_embedding_text,
    stream_openai_chat_completion,
)

__all__ = [
    "convert_to_message",
    "stream_openai_chat_completion",
    "openai_chat_completion",
    "openai_embedding_text",
]

View File

@@ -0,0 +1,289 @@
import logging
import typing
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)

from model_providers.core.bootstrap.openai_protocol import (
    ChatCompletionMessage,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    ChatMessage,
    Embeddings,
    EmbeddingsResponse,
    Finish,
    Role,
    UsageInfo,
)
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    TextEmbeddingResult,
)
from model_providers.core.utils.generic import jsonify

if typing.TYPE_CHECKING:
    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage

logger = logging.getLogger(__name__)

MessageLike = Union[ChatMessage, PromptMessage]

MessageLikeRepresentation = Union[
    MessageLike,
    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
    str,
]


def create_stream_chunk(
    request_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional[Finish] = None,
) -> str:
    choice = ChatCompletionStreamResponseChoice(
        index=index, delta=delta, finish_reason=finish_reason
    )
    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
    return jsonify(chunk)


def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
    """
    Convert PromptMessage to dict for OpenAI Compatibility API
    """
    if isinstance(message, UserPromptMessage):
        message = cast(UserPromptMessage, message)
        if isinstance(message.content, str):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError("User message content must be str")
    elif isinstance(message, AssistantPromptMessage):
        message = cast(AssistantPromptMessage, message)
        message_dict = {"role": "assistant", "content": message.content}
        if message.tool_calls and len(message.tool_calls) > 0:
            message_dict["function_call"] = {
                "name": message.tool_calls[0].function.name,
                "arguments": message.tool_calls[0].function.arguments,
            }
    elif isinstance(message, SystemPromptMessage):
        message = cast(SystemPromptMessage, message)
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, ToolPromptMessage):
        # check if last message is user message
        message = cast(ToolPromptMessage, message)
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Unknown message type {type(message)}")

    return message_dict


def _create_template_from_message_type(
    message_type: str, template: Union[str, list]
) -> PromptMessage:
    """Create a message prompt template from a message type and template string.

    Args:
        message_type: str the type of the message template (e.g., "human", "ai", etc.)
        template: str the template string.

    Returns:
        a message prompt template of the appropriate type.
    """
    if isinstance(template, str):
        content = template
    elif isinstance(template, list):
        content = []
        for tmpl in template:
            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
                if isinstance(tmpl, str):
                    text: str = tmpl
                else:
                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment] # noqa: E501
                content.append(TextPromptMessageContent(data=text))
            elif isinstance(tmpl, dict) and "image_url" in tmpl:
                img_template = cast(dict, tmpl)["image_url"]
                if isinstance(img_template, str):
                    img_template_obj = ImagePromptMessageContent(data=img_template)
                elif isinstance(img_template, dict):
                    img_template = dict(img_template)
                    if "url" in img_template:
                        url = img_template["url"]
                    else:
                        url = None
                    img_template_obj = ImagePromptMessageContent(data=url)
                else:
                    raise ValueError()
                content.append(img_template_obj)
            else:
                raise ValueError()
    else:
        raise ValueError()

    if message_type in ("human", "user"):
        _message = UserPromptMessage(content=content)
    elif message_type in ("ai", "assistant"):
        _message = AssistantPromptMessage(content=content)
    elif message_type == "system":
        _message = SystemPromptMessage(content=content)
    elif message_type in ("function", "tool"):
        _message = ToolPromptMessage(content=content)
    else:
        raise ValueError(
            f"Unexpected message type: {message_type}. Use one of 'human',"
            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
        )
    return _message


def _convert_to_message(
    message: MessageLikeRepresentation,
) -> Union[PromptMessage]:
    """Instantiate a message from a variety of message formats.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - 2-tuple of (message class, template)
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats

    Returns:
        an instance of a message or a message template
    """
    if isinstance(message, ChatMessage):
        _message = _create_template_from_message_type(
            message.role.to_origin_role(), message.content
        )
    elif isinstance(message, PromptMessage):
        _message = message
    elif isinstance(message, str):
        _message = _create_template_from_message_type("human", message)
    elif isinstance(message, tuple):
        if len(message) != 2:
            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
        message_type_str, template = message
        if isinstance(message_type_str, str):
            _message = _create_template_from_message_type(message_type_str, template)
        else:
            raise ValueError(f"Expected message type string, got {message_type_str}")
    else:
        raise NotImplementedError(f"Unsupported message type: {type(message)}")

    return _message


async def _stream_openai_chat_completion(
    response: Generator,
) -> AsyncGenerator[str, None]:
    request_id, model = None, None
    for chunk in response:
        if not isinstance(chunk, LLMResultChunk):
            yield "[ERROR]"
            return
        if model is None:
            model = chunk.model
        if request_id is None:
            request_id = "request_id"
            yield create_stream_chunk(
                request_id,
                model,
                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
            )

        new_token = chunk.delta.message.content
        if new_token:
            delta = ChatCompletionMessage(
                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
                content=new_token,
                tool_calls=chunk.delta.message.tool_calls,
            )
            yield create_stream_chunk(
                request_id=request_id,
                model=model,
                delta=delta,
                index=chunk.delta.index,
                finish_reason=chunk.delta.finish_reason,
            )

    yield create_stream_chunk(
        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"


async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatCompletionMessage(
            **_convert_prompt_message_to_dict(message=response.message)
        ),
        finish_reason=Finish.STOP,
    )
    usage = UsageInfo(
        prompt_tokens=response.usage.prompt_tokens,
        completion_tokens=response.usage.completion_tokens,
        total_tokens=response.usage.total_tokens,
    )
    return ChatCompletionResponse(
        id="request_id",
        model=response.model,
        choices=[choice],
        usage=usage,
    )


async def _openai_embedding_text(response: TextEmbeddingResult) -> EmbeddingsResponse:
    embedding = [
        Embeddings(embedding=embedding, index=index)
        for index, embedding in enumerate(response.embeddings)
    ]
    return EmbeddingsResponse(
        model=response.model,
        data=embedding,
        usage=UsageInfo(
            prompt_tokens=response.usage.tokens,
            total_tokens=response.usage.total_tokens,
            completion_tokens=response.usage.total_tokens,
        ),
    )


convert_to_message = _convert_to_message
stream_openai_chat_completion = _stream_openai_chat_completion
openai_chat_completion = _openai_chat_completion
openai_embedding_text = _openai_embedding_text
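
For reference, a minimal usage sketch of the helpers this module exports (not part of the commit; the message literals below are illustrative):

    from model_providers.bootstrap_web.message_convert import convert_to_message

    # Loose formats accepted by the chat endpoint are normalized to PromptMessage objects.
    prompt_messages = [
        convert_to_message(("system", "You are a helpful assistant.")),  # 2-tuple of (role, text)
        convert_to_message("What is an embedding?"),  # bare string, shorthand for a user message
    ]

openai_chat_completion and openai_embedding_text are coroutines: the server awaits them to wrap an LLMResult or TextEmbeddingResult returned by a model instance into OpenAI-style response models.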

View File

@@ -23,256 +23,38 @@ from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
from uvicorn import Config, Server

-from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
    ProviderListResponse,
    ProviderModelTypeResponse,
)
+from model_providers.bootstrap_web.message_convert import (
+    convert_to_message,
+    openai_chat_completion,
+    openai_embedding_text,
+    stream_openai_chat_completion,
+)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
-    ChatCompletionMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
    EmbeddingsRequest,
    EmbeddingsResponse,
-    Finish,
-    FunctionAvailable,
    ModelCard,
    ModelList,
-    Role,
-    UsageInfo,
)
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
-from model_providers.core.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-)
from model_providers.core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContent,
-    PromptMessageContentType,
    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
)
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeError
-from model_providers.core.utils.generic import dictify, jsonify
+from model_providers.core.utils.generic import dictify

logger = logging.getLogger(__name__)

-MessageLike = Union[ChatMessage, PromptMessage]
-
-MessageLikeRepresentation = Union[
-    MessageLike,
-    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
-    str,
-]
-
-def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
-    """
-    Convert PromptMessage to dict for OpenAI Compatibility API
-    """
-    if isinstance(message, UserPromptMessage):
-        message = cast(UserPromptMessage, message)
-        if isinstance(message.content, str):
-            message_dict = {"role": "user", "content": message.content}
-        else:
-            raise ValueError("User message content must be str")
-    elif isinstance(message, AssistantPromptMessage):
-        message = cast(AssistantPromptMessage, message)
-        message_dict = {"role": "assistant", "content": message.content}
-        if message.tool_calls and len(message.tool_calls) > 0:
-            message_dict["function_call"] = {
-                "name": message.tool_calls[0].function.name,
-                "arguments": message.tool_calls[0].function.arguments,
-            }
-    elif isinstance(message, SystemPromptMessage):
-        message = cast(SystemPromptMessage, message)
-        message_dict = {"role": "system", "content": message.content}
-    elif isinstance(message, ToolPromptMessage):
-        # check if last message is user message
-        message = cast(ToolPromptMessage, message)
-        message_dict = {"role": "function", "content": message.content}
-    else:
-        raise ValueError(f"Unknown message type {type(message)}")
-    return message_dict
-
-def _create_template_from_message_type(
-    message_type: str, template: Union[str, list]
-) -> PromptMessage:
-    """Create a message prompt template from a message type and template string.
-    Args:
-        message_type: str the type of the message template (e.g., "human", "ai", etc.)
-        template: str the template string.
-    Returns:
-        a message prompt template of the appropriate type.
-    """
-    if isinstance(template, str):
-        content = template
-    elif isinstance(template, list):
-        content = []
-        for tmpl in template:
-            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
-                if isinstance(tmpl, str):
-                    text: str = tmpl
-                else:
-                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment] # noqa: E501
-                content.append(TextPromptMessageContent(data=text))
-            elif isinstance(tmpl, dict) and "image_url" in tmpl:
-                img_template = cast(dict, tmpl)["image_url"]
-                if isinstance(img_template, str):
-                    img_template_obj = ImagePromptMessageContent(data=img_template)
-                elif isinstance(img_template, dict):
-                    img_template = dict(img_template)
-                    if "url" in img_template:
-                        url = img_template["url"]
-                    else:
-                        url = None
-                    img_template_obj = ImagePromptMessageContent(data=url)
-                else:
-                    raise ValueError()
-                content.append(img_template_obj)
-            else:
-                raise ValueError()
-    else:
-        raise ValueError()
-    if message_type in ("human", "user"):
-        _message = UserPromptMessage(content=content)
-    elif message_type in ("ai", "assistant"):
-        _message = AssistantPromptMessage(content=content)
-    elif message_type == "system":
-        _message = SystemPromptMessage(content=content)
-    elif message_type in ("function", "tool"):
-        _message = ToolPromptMessage(content=content)
-    else:
-        raise ValueError(
-            f"Unexpected message type: {message_type}. Use one of 'human',"
-            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
-        )
-    return _message
-
-def _convert_to_message(
-    message: MessageLikeRepresentation,
-) -> Union[PromptMessage]:
-    """Instantiate a message from a variety of message formats.
-    The message format can be one of the following:
-    - BaseMessagePromptTemplate
-    - BaseMessage
-    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
-    - 2-tuple of (message class, template)
-    - string: shorthand for ("human", template); e.g., "{user_input}"
-    Args:
-        message: a representation of a message in one of the supported formats
-    Returns:
-        an instance of a message or a message template
-    """
-    if isinstance(message, ChatMessage):
-        _message = _create_template_from_message_type(
-            message.role.to_origin_role(), message.content
-        )
-    elif isinstance(message, PromptMessage):
-        _message = message
-    elif isinstance(message, str):
-        _message = _create_template_from_message_type("human", message)
-    elif isinstance(message, tuple):
-        if len(message) != 2:
-            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
-        message_type_str, template = message
-        if isinstance(message_type_str, str):
-            _message = _create_template_from_message_type(message_type_str, template)
-        else:
-            raise ValueError(f"Expected message type string, got {message_type_str}")
-    else:
-        raise NotImplementedError(f"Unsupported message type: {type(message)}")
-    return _message
-
-async def _stream_openai_chat_completion(
-    response: Generator,
-) -> AsyncGenerator[str, None]:
-    request_id, model = None, None
-    for chunk in response:
-        if not isinstance(chunk, LLMResultChunk):
-            yield "[ERROR]"
-            return
-        if model is None:
-            model = chunk.model
-        if request_id is None:
-            request_id = "request_id"
-            yield create_stream_chunk(
-                request_id,
-                model,
-                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
-            )
-        new_token = chunk.delta.message.content
-        if new_token:
-            delta = ChatCompletionMessage(
-                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
-                content=new_token,
-                tool_calls=chunk.delta.message.tool_calls,
-            )
-            yield create_stream_chunk(
-                request_id=request_id,
-                model=model,
-                delta=delta,
-                index=chunk.delta.index,
-                finish_reason=chunk.delta.finish_reason,
-            )
-    yield create_stream_chunk(
-        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
-    )
-    yield "[DONE]"
-
-async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
-    choice = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatCompletionMessage(
-            **_convert_prompt_message_to_dict(message=response.message)
-        ),
-        finish_reason=Finish.STOP,
-    )
-    usage = UsageInfo(
-        prompt_tokens=response.usage.prompt_tokens,
-        completion_tokens=response.usage.completion_tokens,
-        total_tokens=response.usage.total_tokens,
-    )
-    return ChatCompletionResponse(
-        id="request_id",
-        model=response.model,
-        choices=[choice],
-        usage=usage,
-    )


class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
    """
@@ -363,14 +145,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        started_event.set()

    async def workspaces_model_providers(self, request: Request):
-        provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
-            model_type=request.get("model_type"))
+        provider_list = ProvidersWrapper(
+            provider_manager=self._provider_manager.provider_manager
+        ).get_provider_list(model_type=request.get("model_type"))
        return ProviderListResponse(data=provider_list)

    async def workspaces_model_types(self, model_type: str, request: Request):
        models_by_model_type = ProvidersWrapper(
-            provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
+            provider_manager=self._provider_manager.provider_manager
+        ).get_models_by_model_type(model_type=model_type)
        return ProviderModelTypeResponse(data=models_by_model_type)

    async def list_models(self, provider: str, request: Request):
@@ -403,17 +186,24 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
        return ModelList(data=models_list)

    async def create_embeddings(
        self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
    ):
        logger.info(
            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
        )
-        response = None
-        return EmbeddingsResponse(**dictify(response))
+        model_instance = self._provider_manager.get_model_instance(
+            provider=provider,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=embeddings_request.model,
+        )
+        texts = embeddings_request.input
+        if isinstance(texts, str):
+            texts = [texts]
+        response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
+        return await openai_embedding_text(response)

    async def create_chat_completion(
        self, provider: str, request: Request, chat_request: ChatCompletionRequest
    ):
        logger.info(
            f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -423,7 +213,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
            provider=provider, model_type=ModelType.LLM, model=chat_request.model
        )
        prompt_messages = [
-            _convert_to_message(message) for message in chat_request.messages
+            convert_to_message(message) for message in chat_request.messages
        ]
        tools = []
@@ -458,11 +248,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
            if chat_request.stream:
                return EventSourceResponse(
-                    _stream_openai_chat_completion(response),
+                    stream_openai_chat_completion(response),
                    media_type="text/event-stream",
                )
            else:
-                return await _openai_chat_completion(response)
+                return await openai_chat_completion(response)
        except ValueError as e:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -473,9 +263,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
def run(
    cfg: Dict,
    logging_conf: Optional[dict] = None,
    started_event: mp.Event = None,
):
    logging.config.dictConfig(logging_conf)  # type: ignore
    try:
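
With the example script at the top of this commit serving on 127.0.0.1:20000, the new create_embeddings handler can be exercised with an OpenAI-style request. A hedged sketch: the route path and model name below are assumptions (the actual route registration is not part of this diff); only the host and port come from the example script.

    import requests

    # Hypothetical URL: the provider segment and prefix are registered elsewhere in the bootstrap code.
    resp = requests.post(
        "http://127.0.0.1:20000/openai/v1/embeddings",
        json={"model": "text-embedding-ada-002", "input": ["hello world"]},
        timeout=30,
    )
    print(resp.json()["data"][0]["embedding"][:8])  # first few dimensions of the first vector

The handler accepts either a single string or a list as input, loads the provider's TEXT_EMBEDDING model instance, and returns an EmbeddingsResponse built by openai_embedding_text.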

View File

@@ -1,5 +1,4 @@
-from typing import Optional, List
+from typing import List, Optional

from model_providers.bootstrap_web.entities.model_provider_entities import (
    CustomConfigurationResponse,
@@ -9,10 +8,8 @@ from model_providers.bootstrap_web.entities.model_provider_entities import (
    ProviderWithModelsResponse,
    SystemConfigurationResponse,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager
@@ -22,7 +19,7 @@ class ProvidersWrapper:
        self.provider_manager = provider_manager

    def get_provider_list(
        self, model_type: Optional[str] = None
    ) -> List[ProviderResponse]:
        """
        get provider list.
@@ -38,8 +35,8 @@ class ProvidersWrapper:
            self.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.get_configurations(provider=provider)
+        provider_configurations = self.provider_manager.get_configurations(
+            provider=provider
        )

        provider_responses = []
@@ -47,8 +44,8 @@ class ProvidersWrapper:
            if model_type:
                model_type_entity = ModelType.value_of(model_type)
                if (
                    model_type_entity
                    not in provider_configuration.provider.supported_model_types
                ):
                    continue
@@ -78,7 +75,7 @@ class ProvidersWrapper:
        return provider_responses

    def get_models_by_model_type(
        self, model_type: str
    ) -> List[ProviderWithModelsResponse]:
        """
        get models by model type.
@@ -94,8 +91,8 @@ class ProvidersWrapper:
            self.provider_manager.provider_name_to_provider_model_records_dict.keys()
        )
        # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.get_configurations(provider=provider)
+        provider_configurations = self.provider_manager.get_configurations(
+            provider=provider
        )

        # Get provider available models

View File

@@ -13,7 +13,12 @@ from google.generativeai.types import (
    HarmBlockThreshold,
    HarmCategory,
)
-from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
+from google.generativeai.types.content_types import (
+    FunctionDeclaration,
+    FunctionLibrary,
+    Tool,
+    to_part,
+)

from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResult,
@@ -58,15 +63,15 @@ if you are not sure about the structure.
class GoogleLargeLanguageModel(LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model
@@ -83,15 +88,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(
-            model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
+            model,
+            credentials,
+            prompt_messages,
+            model_parameters,
+            tools,
+            stop,
+            stream,
+            user,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages
@@ -140,15 +152,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            raise CredentialsValidateFailedError(str(ex))

    def _generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model
@@ -163,9 +175,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        :return: full response or stream response chunk generator result
        """
        config_kwargs = model_parameters.copy()
-        config_kwargs.pop(
-            "max_tokens_to_sample", None
-        )
+        config_kwargs.pop("max_tokens_to_sample", None)
        # https://github.com/google/generative-ai-python/issues/170
        # config_kwargs["max_output_tokens"] = config_kwargs.pop(
        #     "max_tokens_to_sample", None
@@ -206,11 +216,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        }
        tools_one = []
        for tool in tools:
-            one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
-                                                                       description=tool.description,
-                                                                       parameters=tool.parameters
-                                                                       )
-                                                   ])
+            one_tool = Tool(
+                function_declarations=[
+                    FunctionDeclaration(
+                        name=tool.name,
+                        description=tool.description,
+                        parameters=tool.parameters,
+                    )
+                ]
+            )
            tools_one.append(one_tool)

        response = google_model.generate_content(
@@ -231,11 +245,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        )

    def _handle_generate_response(
        self,
        model: str,
        credentials: dict,
        response: GenerateContentResponse,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Handle llm response
@@ -262,7 +276,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                tool_calls.append(function_call)

        # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
+        assistant_prompt_message = AssistantPromptMessage(
+            content=part.text, tool_calls=tool_calls
+        )

        # calculate num tokens
        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -286,11 +302,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        return result

    def _handle_generate_stream_response(
        self,
        model: str,
        credentials: dict,
        response: GenerateContentResponse,
        prompt_messages: list[PromptMessage],
    ) -> Generator:
        """
        Handle llm stream response
@@ -446,7 +462,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        }

    def _extract_response_function_call(
        self, response_function_call: Union[FunctionCall, FunctionResponse]
    ) -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response
@@ -471,7 +487,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                arguments=str(map_composite_dict),
            )
        else:
-            raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
+            raise ValueError(
+                f"Unsupported response_function_call type: {type(response_function_call)}"
+            )

        tool_call = AssistantPromptMessage.ToolCall(
            id=response_function_call.name, type="function", function=function