mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 06:03:14 +08:00

embedding convert endpoint

This commit is contained in:
parent 5169228b86
commit 051acfbeae
@@ -3,11 +3,15 @@ import asyncio
 import logging
 
 from model_providers import BootstrapWebBuilder
-from model_providers.core.utils.utils import get_config_dict, get_log_file, get_timestamp_ms
+from model_providers.core.utils.utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+)
 
 logger = logging.getLogger(__name__)
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model-providers",
@@ -26,9 +30,7 @@ if __name__ == '__main__':
     )
     boot = (
         BootstrapWebBuilder()
-        .model_providers_cfg_path(
-            model_providers_cfg_path=args.model_providers
-        )
+        .model_providers_cfg_path(model_providers_cfg_path=args.model_providers)
         .host(host="127.0.0.1")
         .port(port=20000)
         .build()
@@ -36,11 +38,9 @@ if __name__ == '__main__':
         boot.set_app_event(started_event=None)
         boot.serve(logging_conf=logging_conf)
 
-
         async def pool_join_thread():
             await boot.join()
 
-
         asyncio.run(pool_join_thread())
     except SystemExit:
         logger.info("SystemExit raised, exiting")
@@ -1,27 +0,0 @@
-import typing
-from subprocess import Popen
-from typing import Optional
-
-from model_providers.core.bootstrap.openai_protocol import (
-    ChatCompletionStreamResponse,
-    ChatCompletionStreamResponseChoice,
-    Finish,
-)
-from model_providers.core.utils.generic import jsonify
-
-if typing.TYPE_CHECKING:
-    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
-
-
-def create_stream_chunk(
-    request_id: str,
-    model: str,
-    delta: "ChatCompletionMessage",
-    index: Optional[int] = 0,
-    finish_reason: Optional[Finish] = None,
-) -> str:
-    choice = ChatCompletionStreamResponseChoice(
-        index=index, delta=delta, finish_reason=finish_reason
-    )
-    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
-    return jsonify(chunk)
@@ -0,0 +1,13 @@
+from model_providers.bootstrap_web.message_convert.core import (
+    convert_to_message,
+    openai_chat_completion,
+    openai_embedding_text,
+    stream_openai_chat_completion,
+)
+
+__all__ = [
+    "convert_to_message",
+    "stream_openai_chat_completion",
+    "openai_chat_completion",
+    "openai_embedding_text",
+]
@@ -0,0 +1,289 @@
+import logging
+import typing
+from typing import (
+    Any,
+    AsyncGenerator,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+
+from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionMessage,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionStreamResponse,
+    ChatCompletionStreamResponseChoice,
+    ChatMessage,
+    Embeddings,
+    EmbeddingsResponse,
+    Finish,
+    Role,
+    UsageInfo,
+)
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+)
+from model_providers.core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from model_providers.core.model_runtime.entities.text_embedding_entities import (
+    TextEmbeddingResult,
+)
+from model_providers.core.utils.generic import jsonify
+
+if typing.TYPE_CHECKING:
+    from model_providers.core.bootstrap.openai_protocol import ChatCompletionMessage
+
+logger = logging.getLogger(__name__)
+
+MessageLike = Union[ChatMessage, PromptMessage]
+
+MessageLikeRepresentation = Union[
+    MessageLike,
+    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
+    str,
+]
+
+
+def create_stream_chunk(
+    request_id: str,
+    model: str,
+    delta: "ChatCompletionMessage",
+    index: Optional[int] = 0,
+    finish_reason: Optional[Finish] = None,
+) -> str:
+    choice = ChatCompletionStreamResponseChoice(
+        index=index, delta=delta, finish_reason=finish_reason
+    )
+    chunk = ChatCompletionStreamResponse(id=request_id, model=model, choices=[choice])
+    return jsonify(chunk)
+
+
+def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
+    """
+    Convert PromptMessage to dict for OpenAI Compatibility API
+    """
+    if isinstance(message, UserPromptMessage):
+        message = cast(UserPromptMessage, message)
+        if isinstance(message.content, str):
+            message_dict = {"role": "user", "content": message.content}
+        else:
+            raise ValueError("User message content must be str")
+    elif isinstance(message, AssistantPromptMessage):
+        message = cast(AssistantPromptMessage, message)
+        message_dict = {"role": "assistant", "content": message.content}
+        if message.tool_calls and len(message.tool_calls) > 0:
+            message_dict["function_call"] = {
+                "name": message.tool_calls[0].function.name,
+                "arguments": message.tool_calls[0].function.arguments,
+            }
+    elif isinstance(message, SystemPromptMessage):
+        message = cast(SystemPromptMessage, message)
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, ToolPromptMessage):
+        # check if last message is user message
+        message = cast(ToolPromptMessage, message)
+        message_dict = {"role": "function", "content": message.content}
+    else:
+        raise ValueError(f"Unknown message type {type(message)}")
+
+    return message_dict
+
+
+def _create_template_from_message_type(
+    message_type: str, template: Union[str, list]
+) -> PromptMessage:
+    """Create a message prompt template from a message type and template string.
+
+    Args:
+        message_type: str the type of the message template (e.g., "human", "ai", etc.)
+        template: str the template string.
+
+    Returns:
+        a message prompt template of the appropriate type.
+    """
+    if isinstance(template, str):
+        content = template
+    elif isinstance(template, list):
+        content = []
+        for tmpl in template:
+            if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
+                if isinstance(tmpl, str):
+                    text: str = tmpl
+                else:
+                    text = cast(dict, tmpl)["text"]  # type: ignore[assignment] # noqa: E501
+                content.append(TextPromptMessageContent(data=text))
+            elif isinstance(tmpl, dict) and "image_url" in tmpl:
+                img_template = cast(dict, tmpl)["image_url"]
+                if isinstance(img_template, str):
+                    img_template_obj = ImagePromptMessageContent(data=img_template)
+                elif isinstance(img_template, dict):
+                    img_template = dict(img_template)
+                    if "url" in img_template:
+                        url = img_template["url"]
+                    else:
+                        url = None
+                    img_template_obj = ImagePromptMessageContent(data=url)
+                else:
+                    raise ValueError()
+                content.append(img_template_obj)
+            else:
+                raise ValueError()
+    else:
+        raise ValueError()
+
+    if message_type in ("human", "user"):
+        _message = UserPromptMessage(content=content)
+    elif message_type in ("ai", "assistant"):
+        _message = AssistantPromptMessage(content=content)
+    elif message_type == "system":
+        _message = SystemPromptMessage(content=content)
+    elif message_type in ("function", "tool"):
+        _message = ToolPromptMessage(content=content)
+    else:
+        raise ValueError(
+            f"Unexpected message type: {message_type}. Use one of 'human',"
+            f" 'user', 'ai', 'assistant', or 'system' and 'function' or 'tool'."
+        )
+
+    return _message
+
+
+def _convert_to_message(
+    message: MessageLikeRepresentation,
+) -> Union[PromptMessage]:
+    """Instantiate a message from a variety of message formats.
+
+    The message format can be one of the following:
+
+    - BaseMessagePromptTemplate
+    - BaseMessage
+    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
+    - 2-tuple of (message class, template)
+    - string: shorthand for ("human", template); e.g., "{user_input}"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message or a message template
+    """
+    if isinstance(message, ChatMessage):
+        _message = _create_template_from_message_type(
+            message.role.to_origin_role(), message.content
+        )
+
+    elif isinstance(message, PromptMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_template_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
+        message_type_str, template = message
+        if isinstance(message_type_str, str):
+            _message = _create_template_from_message_type(message_type_str, template)
+        else:
+            raise ValueError(f"Expected message type string, got {message_type_str}")
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+async def _stream_openai_chat_completion(
+    response: Generator,
+) -> AsyncGenerator[str, None]:
+    request_id, model = None, None
+    for chunk in response:
+        if not isinstance(chunk, LLMResultChunk):
+            yield "[ERROR]"
+            return
+
+        if model is None:
+            model = chunk.model
+        if request_id is None:
+            request_id = "request_id"
+            yield create_stream_chunk(
+                request_id,
+                model,
+                ChatCompletionMessage(role=Role.ASSISTANT, content=""),
+            )
+
+        new_token = chunk.delta.message.content
+
+        if new_token:
+            delta = ChatCompletionMessage(
+                role=Role.value_of(chunk.delta.message.role.to_origin_role()),
+                content=new_token,
+                tool_calls=chunk.delta.message.tool_calls,
+            )
+            yield create_stream_chunk(
+                request_id=request_id,
+                model=model,
+                delta=delta,
+                index=chunk.delta.index,
+                finish_reason=chunk.delta.finish_reason,
+            )
+
+    yield create_stream_chunk(
+        request_id, model, ChatCompletionMessage(), finish_reason=Finish.STOP
+    )
+    yield "[DONE]"
+
+
+async def _openai_chat_completion(response: LLMResult) -> ChatCompletionResponse:
+    choice = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatCompletionMessage(
+            **_convert_prompt_message_to_dict(message=response.message)
+        ),
+        finish_reason=Finish.STOP,
+    )
+    usage = UsageInfo(
+        prompt_tokens=response.usage.prompt_tokens,
+        completion_tokens=response.usage.completion_tokens,
+        total_tokens=response.usage.total_tokens,
+    )
+    return ChatCompletionResponse(
+        id="request_id",
+        model=response.model,
+        choices=[choice],
+        usage=usage,
+    )
+
+
+async def _openai_embedding_text(response: TextEmbeddingResult) -> EmbeddingsResponse:
+    embedding = [
+        Embeddings(embedding=embedding, index=index)
+        for index, embedding in enumerate(response.embeddings)
+    ]
+
+    return EmbeddingsResponse(
+        model=response.model,
+        data=embedding,
+        usage=UsageInfo(
+            prompt_tokens=response.usage.tokens,
+            total_tokens=response.usage.total_tokens,
+            completion_tokens=response.usage.total_tokens,
+        ),
+    )
+
+
+convert_to_message = _convert_to_message
+stream_openai_chat_completion = _stream_openai_chat_completion
+openai_chat_completion = _openai_chat_completion
+openai_embedding_text = _openai_embedding_text
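A quick usage sketch for the new conversion helpers (a minimal sketch, assuming the model_providers package from this repository is importable; the tuple and plain-string input forms come straight from the _convert_to_message docstring above):

    from model_providers.bootstrap_web.message_convert import convert_to_message

    # ("human", ...) becomes a UserPromptMessage, ("system", ...) a SystemPromptMessage
    user_msg = convert_to_message(("human", "What is the capital of France?"))
    system_msg = convert_to_message(("system", "You are a terse assistant."))
    print(type(user_msg).__name__, user_msg.content)
    print(type(system_msg).__name__, system_msg.content)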
@@ -23,256 +23,38 @@ from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette import EventSourceResponse
 from uvicorn import Config, Server
 
-from model_providers.bootstrap_web.common import create_stream_chunk
 from model_providers.bootstrap_web.entities.model_provider_entities import (
     ProviderListResponse,
     ProviderModelTypeResponse,
 )
+from model_providers.bootstrap_web.message_convert import (
+    convert_to_message,
+    openai_chat_completion,
+    openai_embedding_text,
+    stream_openai_chat_completion,
+)
 from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 from model_providers.core.bootstrap.openai_protocol import (
-    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
     EmbeddingsRequest,
     EmbeddingsResponse,
-    Finish,
-    FunctionAvailable,
     ModelCard,
     ModelList,
-    Role,
-    UsageInfo,
 )
 from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
-from model_providers.core.model_runtime.entities.llm_entities import (
-    LLMResult,
-    LLMResultChunk,
-)
 from model_providers.core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContent,
-    PromptMessageContentType,
     PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
 )
 from model_providers.core.model_runtime.entities.model_entities import (
     AIModelEntity,
     ModelType,
 )
 from model_providers.core.model_runtime.errors.invoke import InvokeError
-from model_providers.core.utils.generic import dictify, jsonify
+from model_providers.core.utils.generic import dictify
 
 logger = logging.getLogger(__name__)
 
[… removed block elided: the MessageLike / MessageLikeRepresentation aliases and the helpers _convert_prompt_message_to_dict, _create_template_from_message_type, _convert_to_message, _stream_openai_chat_completion, and _openai_chat_completion, which this commit moves verbatim into the new message_convert core module shown above …]
 
 
 class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     """
@@ -363,14 +145,15 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         started_event.set()
 
     async def workspaces_model_providers(self, request: Request):
-        provider_list = ProvidersWrapper(provider_manager=self._provider_manager.provider_manager).get_provider_list(
-            model_type=request.get("model_type"))
+        provider_list = ProvidersWrapper(
+            provider_manager=self._provider_manager.provider_manager
+        ).get_provider_list(model_type=request.get("model_type"))
         return ProviderListResponse(data=provider_list)
 
     async def workspaces_model_types(self, model_type: str, request: Request):
         models_by_model_type = ProvidersWrapper(
-            provider_manager=self._provider_manager.provider_manager).get_models_by_model_type(model_type=model_type)
+            provider_manager=self._provider_manager.provider_manager
+        ).get_models_by_model_type(model_type=model_type)
         return ProviderModelTypeResponse(data=models_by_model_type)
 
     async def list_models(self, provider: str, request: Request):
@@ -403,17 +186,24 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         return ModelList(data=models_list)
 
     async def create_embeddings(
         self, provider: str, request: Request, embeddings_request: EmbeddingsRequest
     ):
         logger.info(
             f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
         )
-        response = None
-        return EmbeddingsResponse(**dictify(response))
+        model_instance = self._provider_manager.get_model_instance(
+            provider=provider,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=embeddings_request.model,
+        )
+        texts = embeddings_request.input
+        if isinstance(texts, str):
+            texts = [texts]
+        response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
+        return await openai_embedding_text(response)
 
     async def create_chat_completion(
         self, provider: str, request: Request, chat_request: ChatCompletionRequest
     ):
         logger.info(
             f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
@@ -423,7 +213,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
             provider=provider, model_type=ModelType.LLM, model=chat_request.model
         )
         prompt_messages = [
-            _convert_to_message(message) for message in chat_request.messages
+            convert_to_message(message) for message in chat_request.messages
         ]
 
         tools = []
@@ -458,11 +248,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
 
             if chat_request.stream:
                 return EventSourceResponse(
-                    _stream_openai_chat_completion(response),
+                    stream_openai_chat_completion(response),
                     media_type="text/event-stream",
                 )
             else:
-                return await _openai_chat_completion(response)
+                return await openai_chat_completion(response)
         except ValueError as e:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
 
@@ -473,9 +263,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
 
 
 def run(
     cfg: Dict,
     logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
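The rewritten create_embeddings handler now resolves a TEXT_EMBEDDING model instance, normalizes a single string input into a list, invokes the embedding model, and converts the runtime's TextEmbeddingResult into an OpenAI-style EmbeddingsResponse via openai_embedding_text. A minimal client sketch, assuming the server started by the script above is listening on 127.0.0.1:20000 and that the embeddings route is mounted under a per-provider prefix such as /{provider}/v1/embeddings (the exact route registration, provider name, and model name below are assumptions, not part of this diff):

    import requests

    # Hypothetical provider and model names; substitute whatever is configured
    # in your model_providers configuration.
    resp = requests.post(
        "http://127.0.0.1:20000/openai/v1/embeddings",
        json={"model": "text-embedding-ada-002", "input": "hello world"},
        timeout=60,
    )
    data = resp.json()
    print(data["model"], len(data["data"][0]["embedding"]))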
@@ -1,5 +1,4 @@
-from typing import Optional, List
+from typing import List, Optional
 
-
 from model_providers.bootstrap_web.entities.model_provider_entities import (
     CustomConfigurationResponse,
@@ -9,10 +8,8 @@ from model_providers.bootstrap_web.entities.model_provider_entities import (
     ProviderWithModelsResponse,
     SystemConfigurationResponse,
 )
-
 from model_providers.core.entities.model_entities import ModelStatus
 from model_providers.core.entities.provider_entities import ProviderType
-
 from model_providers.core.model_runtime.entities.model_entities import ModelType
 from model_providers.core.provider_manager import ProviderManager
 
@@ -22,7 +19,7 @@ class ProvidersWrapper:
         self.provider_manager = provider_manager
 
     def get_provider_list(
         self, model_type: Optional[str] = None
     ) -> List[ProviderResponse]:
         """
         get provider list.
@@ -38,8 +35,8 @@ class ProvidersWrapper:
             self.provider_manager.provider_name_to_provider_model_records_dict.keys()
         )
         # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.get_configurations(provider=provider)
+        provider_configurations = self.provider_manager.get_configurations(
+            provider=provider
         )
 
         provider_responses = []
@@ -47,8 +44,8 @@ class ProvidersWrapper:
             if model_type:
                 model_type_entity = ModelType.value_of(model_type)
                 if (
                     model_type_entity
                     not in provider_configuration.provider.supported_model_types
                 ):
                     continue
 
@@ -78,7 +75,7 @@ class ProvidersWrapper:
         return provider_responses
 
     def get_models_by_model_type(
         self, model_type: str
     ) -> List[ProviderWithModelsResponse]:
         """
         get models by model type.
@@ -94,8 +91,8 @@ class ProvidersWrapper:
             self.provider_manager.provider_name_to_provider_model_records_dict.keys()
         )
         # Get all provider configurations of the current workspace
-        provider_configurations = (
-            self.provider_manager.get_configurations(provider=provider)
+        provider_configurations = self.provider_manager.get_configurations(
+            provider=provider
         )
 
         # Get provider available models
@@ -13,7 +13,12 @@ from google.generativeai.types import (
     HarmBlockThreshold,
     HarmCategory,
 )
-from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary
+from google.generativeai.types.content_types import (
+    FunctionDeclaration,
+    FunctionLibrary,
+    Tool,
+    to_part,
+)
 
 from model_providers.core.model_runtime.entities.llm_entities import (
     LLMResult,
@@ -58,15 +63,15 @@ if you are not sure about the structure.
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
         self,
         model: str,
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -83,15 +88,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(
-            model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
+            model,
+            credentials,
+            prompt_messages,
+            model_parameters,
+            tools,
+            stop,
+            stream,
+            user,
         )
 
     def get_num_tokens(
         self,
         model: str,
         credentials: dict,
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
         """
         Get number of tokens for given prompt messages
@@ -140,15 +152,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def _generate(
         self,
         model: str,
         credentials: dict,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
         stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -163,9 +175,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs.pop(
-            "max_tokens_to_sample", None
-        )
+        config_kwargs.pop("max_tokens_to_sample", None)
         # https://github.com/google/generative-ai-python/issues/170
         # config_kwargs["max_output_tokens"] = config_kwargs.pop(
         #     "max_tokens_to_sample", None
@@ -206,11 +216,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             }
         tools_one = []
         for tool in tools:
-            one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
-                                                                       description=tool.description,
-                                                                       parameters=tool.parameters
-                                                                       )
-                            ])
+            one_tool = Tool(
+                function_declarations=[
+                    FunctionDeclaration(
+                        name=tool.name,
+                        description=tool.description,
+                        parameters=tool.parameters,
+                    )
+                ]
+            )
             tools_one.append(one_tool)
 
         response = google_model.generate_content(
@@ -231,11 +245,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         )
 
     def _handle_generate_response(
         self,
         model: str,
         credentials: dict,
         response: GenerateContentResponse,
         prompt_messages: list[PromptMessage],
     ) -> LLMResult:
         """
         Handle llm response
@@ -262,7 +276,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 tool_calls.append(function_call)
 
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)
+        assistant_prompt_message = AssistantPromptMessage(
+            content=part.text, tool_calls=tool_calls
+        )
 
         # calculate num tokens
         prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -286,11 +302,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         return result
 
     def _handle_generate_stream_response(
         self,
         model: str,
         credentials: dict,
         response: GenerateContentResponse,
         prompt_messages: list[PromptMessage],
     ) -> Generator:
         """
         Handle llm stream response
@@ -446,7 +462,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         }
 
     def _extract_response_function_call(
         self, response_function_call: Union[FunctionCall, FunctionResponse]
     ) -> AssistantPromptMessage.ToolCall:
         """
         Extract function call from response
@@ -471,7 +487,9 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 arguments=str(map_composite_dict),
             )
         else:
-            raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")
+            raise ValueError(
+                f"Unsupported response_function_call type: {type(response_function_call)}"
+            )
 
         tool_call = AssistantPromptMessage.ToolCall(
             id=response_function_call.name, type="function", function=function