diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py index 486e453c..7e45d16f 100644 --- a/model-providers/model_providers/__init__.py +++ b/model-providers/model_providers/__init__.py @@ -1,17 +1,23 @@ from chatchat.configs import MODEL_PLATFORMS + from model_providers.core.model_manager import ModelManager + def _to_custom_provide_configuration(): provider_name_to_provider_records_dict = {} provider_name_to_provider_model_records_dict = {} - return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict + return ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) + # 基于配置管理器创建的模型实例 provider_manager = ModelManager( provider_name_to_provider_records_dict={ - 'openai': { - 'openai_api_key': "sk-4M9LYF", + "openai": { + "openai_api_key": "sk-4M9LYF", } }, - provider_name_to_provider_model_records_dict={} -) \ No newline at end of file + provider_name_to_provider_model_records_dict={}, +) diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py index ef5fab6e..558a3f7a 100644 --- a/model-providers/model_providers/__main__.py +++ b/model-providers/model_providers/__main__.py @@ -1,51 +1,58 @@ import os -from typing import cast, Generator +from typing import Generator, cast from model_providers.core.model_manager import ModelManager -from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta -from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, +) +from model_providers.core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + UserPromptMessage, +) from model_providers.core.model_runtime.entities.model_entities import ModelType -if __name__ == '__main__': +if __name__ == "__main__": # 基于配置管理器创建的模型实例 provider_manager = ModelManager( provider_name_to_provider_records_dict={ - 'openai': { - 'openai_api_key': "sk-4M9LYF", + "openai": { + "openai_api_key": "sk-4M9LYF", } }, - provider_name_to_provider_model_records_dict={} + provider_name_to_provider_model_records_dict={}, ) # # Invoke model - model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4') + model_instance = provider_manager.get_model_instance( + provider="openai", model_type=ModelType.LLM, model="gpt-4" + ) response = model_instance.invoke_llm( - - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=True, - user="abc-123" + user="abc-123", ) assert isinstance(response, Generator) - total_message = '' + total_message = "" for chunk in response: assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) total_message += chunk.delta.message.content - assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True + assert ( + len(chunk.delta.message.content) > 0 + if not chunk.delta.finish_reason + else True + ) print(total_message) - assert '参考资料' in total_message + 
assert "参考资料" in total_message diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index 7e5ef088..56adcce9 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -1,60 +1,58 @@ import asyncio -import os -from typing import Optional, Any, Dict - -from fastapi import (APIRouter, - FastAPI, - HTTPException, - Response, - Request, - status - ) -import logging -from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb import json -import pprint -import tiktoken -from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \ - ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable -from uvicorn import Config, Server -from fastapi.middleware.cors import CORSMiddleware +import logging import multiprocessing as mp +import os +import pprint import threading +from typing import Any, Dict, Optional + +import tiktoken +from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status +from fastapi.middleware.cors import CORSMiddleware from sse_starlette import EventSourceResponse +from uvicorn import Config, Server -from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage +from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + FunctionAvailable, + ModelList, +) +from model_providers.core.model_runtime.entities.message_entities import ( + UserPromptMessage, +) from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.utils.generic import dictify, jsonify - from model_providers.core.model_runtime.model_providers import model_provider_factory +from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) -async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest): +async def create_stream_chat_completion( + model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest +): try: - - response = model_type_instance.invoke( model=chat_request.model, credentials={ - 'openai_api_key': "sk-", - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], - model_parameters={ - **chat_request.to_model_parameters_dict() + "openai_api_key": "sk-", + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], + model_parameters={**chat_request.to_model_parameters_dict()}, stop=chat_request.stop, stream=chat_request.stream, - user="abc-123" + user="abc-123", ) return 
response @@ -81,7 +79,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): host = cfg.get("host", "127.0.0.1") port = cfg.get("port", 20000) - logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}") + logger.info( + f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}" + ) return cls(host=host, port=port) def serve(self, logging_conf: Optional[dict] = None): @@ -140,8 +140,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): async def list_models(self, request: Request): pass - async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest): - logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}") + async def create_embeddings( + self, request: Request, embeddings_request: EmbeddingsRequest + ): + logger.info( + f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" + ) if os.environ["API_KEY"] is None: authorization = request.headers.get("Authorization") authorization = authorization.split("Bearer ")[-1] @@ -171,42 +175,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): ) return EmbeddingsResponse(**dictify(response)) - async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest): - logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}") + async def create_chat_completion( + self, request: Request, chat_request: ChatCompletionRequest + ): + logger.info( + f"Received chat completion request: {pprint.pformat(chat_request.dict())}" + ) if os.environ["API_KEY"] is None: authorization = request.headers.get("Authorization") authorization = authorization.split("Bearer ")[-1] else: authorization = os.environ["API_KEY"] - model_provider_factory.get_providers(provider_name='openai') - provider_instance = model_provider_factory.get_provider_instance('openai') + model_provider_factory.get_providers(provider_name="openai") + provider_instance = model_provider_factory.get_provider_instance("openai") model_type_instance = provider_instance.get_model_instance(ModelType.LLM) if chat_request.stream: generator = create_stream_chat_completion(model_type_instance, chat_request) return EventSourceResponse(generator, media_type="text/event-stream") else: - response = model_type_instance.invoke( - model='gpt-4', + model="gpt-4", credentials={ - 'openai_api_key': "sk-", - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') + "openai_api_key": "sk-", + "minimax_api_key": os.environ.get("MINIMAX_API_KEY"), + "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"), }, - prompt_messages=[ - UserPromptMessage( - content='北京今天的天气怎么样' - ) - ], + prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")], model_parameters={ - 'temperature': 0.7, - 'top_p': 1.0, - 'top_k': 1, - 'plugin_web_search': True, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 1, + "plugin_web_search": True, }, - stop=['you'], + stop=["you"], stream=False, - user="abc-123" + user="abc-123", ) chat_response = ChatCompletionResponse(**dictify(response)) @@ -215,15 +218,19 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: import signal + # 跳过键盘中断,使用xoscar的信号处理 
signal.signal(signal.SIGINT, lambda *_: None) - api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {})) + api = RESTFulOpenAIBootstrapBaseWeb.from_config( + cfg=cfg.get("run_openai_api", {}) + ) api.set_app_event(started_event=started_event) api.serve(logging_conf=logging_conf) diff --git a/model-providers/model_providers/core/bootstrap/__init__.py b/model-providers/model_providers/core/bootstrap/__init__.py index 1d45692b..f1da737b 100644 --- a/model-providers/model_providers/core/bootstrap/__init__.py +++ b/model-providers/model_providers/core/bootstrap/__init__.py @@ -1,6 +1,6 @@ - from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb from model_providers.core.bootstrap.bootstrap_register import bootstrap_register + __all__ = [ "bootstrap_register", "Bootstrap", diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index 406a27ce..b2da1a0b 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -1,11 +1,13 @@ from abc import abstractmethod from collections import deque + from fastapi import Request class Bootstrap: """最大的任务队列""" + _MAX_ONGOING_TASKS: int = 1 """任务队列""" @@ -37,7 +39,6 @@ class Bootstrap: class OpenAIBootstrapBaseWeb(Bootstrap): - def __init__(self): super().__init__() @@ -46,9 +47,13 @@ class OpenAIBootstrapBaseWeb(Bootstrap): pass @abstractmethod - async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest): + async def create_embeddings( + self, request: Request, embeddings_request: EmbeddingsRequest + ): pass @abstractmethod - async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest): + async def create_chat_completion( + self, request: Request, chat_request: ChatCompletionRequest + ): pass diff --git a/model-providers/model_providers/core/bootstrap/bootstrap_register.py b/model-providers/model_providers/core/bootstrap/bootstrap_register.py index ef78184a..aade63cd 100644 --- a/model-providers/model_providers/core/bootstrap/bootstrap_register.py +++ b/model-providers/model_providers/core/bootstrap/bootstrap_register.py @@ -5,6 +5,7 @@ class BootstrapRegister: """ 注册管理器 """ + mapping = { "bootstrap": {}, } @@ -48,4 +49,3 @@ class BootstrapRegister: bootstrap_register = BootstrapRegister() - diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 690475fd..1d7354cf 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -1,6 +1,7 @@ import time from enum import Enum from typing import Any, Dict, List, Optional, Union + from pydantic import BaseModel, Field, root_validator from typing_extensions import Literal @@ -86,13 +87,15 @@ class ChatCompletionRequest(BaseModel): top_k: Optional[float] = None n: int = 1 max_tokens: Optional[int] = None - stop: Optional[list[str]] = None, + stop: Optional[list[str]] = None stream: Optional[bool] = False def to_model_parameters_dict(self, *args, **kwargs): # 调用父类的to_dict方法,并排除tools字段 helper.dump_model - return super().dict(exclude={'tools','messages','functions','function_call'}, *args, **kwargs) + return super().dict( + exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs + ) class ChatCompletionResponseChoice(BaseModel): diff --git 
a/model-providers/model_providers/core/entities/agent_entities.py b/model-providers/model_providers/core/entities/agent_entities.py index 0cdf8670..656bf4aa 100644 --- a/model-providers/model_providers/core/entities/agent_entities.py +++ b/model-providers/model_providers/core/entities/agent_entities.py @@ -2,7 +2,7 @@ from enum import Enum class PlanningStrategy(Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/model-providers/model_providers/core/entities/application_entities.py b/model-providers/model_providers/core/entities/application_entities.py index c9a72176..9a5e0ff4 100644 --- a/model-providers/model_providers/core/entities/application_entities.py +++ b/model-providers/model_providers/core/entities/application_entities.py @@ -5,7 +5,9 @@ from pydantic import BaseModel from model_providers.core.entities.provider_configuration import ProviderModelBundle from model_providers.core.file.file_obj import FileObj -from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessageRole, +) from model_providers.core.model_runtime.entities.model_entities import AIModelEntity @@ -13,6 +15,7 @@ class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -27,6 +30,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -35,6 +39,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -47,6 +52,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -64,11 +70,12 @@ class PromptTemplateEntity(BaseModel): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str) -> "PromptType": """ Get value of given mode. @@ -78,18 +85,21 @@ class PromptTemplateEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None - advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None + advanced_completion_prompt_template: Optional[ + AdvancedCompletionPromptTemplateEntity + ] = None class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. """ + variable: str type: str config: dict[str, Any] = {} @@ -105,11 +115,12 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str) -> "RetrieveStrategy": """ Get value of given mode. 
@@ -119,7 +130,7 @@ class DatasetRetrieveConfigEntity(BaseModel): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion @@ -134,6 +145,7 @@ class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -142,6 +154,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -150,6 +163,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -159,6 +173,7 @@ class FileUploadEntity(BaseModel): """ File Upload Entity. """ + image_config: Optional[dict[str, Any]] = None @@ -166,6 +181,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ + provider_type: Literal["builtin", "api"] provider_id: str tool_name: str @@ -176,6 +192,7 @@ class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. """ + first_prompt: str next_iteration: str @@ -189,6 +206,7 @@ class AgentScratchpadUnit(BaseModel): """ Action Entity. """ + action_name: str action_input: Union[dict, str] @@ -208,8 +226,9 @@ class AgentEntity(BaseModel): """ Agent Strategy. """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str @@ -223,6 +242,7 @@ class AppOrchestrationConfigEntity(BaseModel): """ App Orchestration Config Entity. """ + model_config: ModelConfigEntity prompt_template: PromptTemplateEntity external_data_variables: list[ExternalDataVariableEntity] = [] @@ -244,13 +264,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str) -> "InvokeFrom": """ Get value of given mode. @@ -260,7 +281,7 @@ class InvokeFrom(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -269,21 +290,22 @@ class InvokeFrom(Enum): :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ApplicationGenerateEntity(BaseModel): """ Application Generate Entity. 
""" + task_id: str tenant_id: str diff --git a/model-providers/model_providers/core/entities/message_entities.py b/model-providers/model_providers/core/entities/message_entities.py index d9217512..b7ad8172 100644 --- a/model-providers/model_providers/core/entities/message_entities.py +++ b/model-providers/model_providers/core/entities/message_entities.py @@ -1,7 +1,13 @@ import enum from typing import Any, cast -from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage +from langchain.schema import ( + AIMessage, + BaseMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) from pydantic import BaseModel from model_providers.core.model_runtime.entities.message_entities import ( @@ -16,7 +22,7 @@ from model_providers.core.model_runtime.entities.message_entities import ( class PromptMessageFileType(enum.Enum): - IMAGE = 'image' + IMAGE = "image" @staticmethod def value_of(value): @@ -33,8 +39,8 @@ class PromptMessageFile(BaseModel): class ImagePromptMessageFile(PromptMessageFile): class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW @@ -55,32 +61,39 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe for file in message.files: if file.type == PromptMessageFileType.IMAGE: file = cast(ImagePromptMessageFile, file) - file_prompt_message_contents.append(ImagePromptMessageContent( - data=file.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW - )) + file_prompt_message_contents.append( + ImagePromptMessageContent( + data=file.data, + detail=ImagePromptMessageContent.DETAIL.HIGH + if file.detail.value == "high" + else ImagePromptMessageContent.DETAIL.LOW, + ) + ) - prompt_message_contents = [TextPromptMessageContent(data=message.content)] + prompt_message_contents = [ + TextPromptMessageContent(data=message.content) + ] prompt_message_contents.extend(file_prompt_message_contents) - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + prompt_messages.append( + UserPromptMessage(content=prompt_message_contents) + ) else: prompt_messages.append(UserPromptMessage(content=message.content)) elif isinstance(message, AIMessage): - message_kwargs = { - 'content': message.content - } + message_kwargs = {"content": message.content} - if 'function_call' in message.additional_kwargs: - message_kwargs['tool_calls'] = [ + if "function_call" in message.additional_kwargs: + message_kwargs["tool_calls"] = [ AssistantPromptMessage.ToolCall( - id=message.additional_kwargs['function_call']['id'], - type='function', + id=message.additional_kwargs["function_call"]["id"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.additional_kwargs['function_call']['name'], - arguments=message.additional_kwargs['function_call']['arguments'] - ) + name=message.additional_kwargs["function_call"]["name"], + arguments=message.additional_kwargs["function_call"][ + "arguments" + ], + ), ) ] @@ -88,12 +101,16 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe elif isinstance(message, SystemMessage): prompt_messages.append(SystemPromptMessage(content=message.content)) elif isinstance(message, FunctionMessage): - prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name)) + prompt_messages.append( + 
ToolPromptMessage(content=message.content, tool_call_id=message.name) + ) return prompt_messages -def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]: +def prompt_messages_to_lc_messages( + prompt_messages: list[PromptMessage], +) -> list[BaseMessage]: messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): @@ -105,24 +122,24 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list if isinstance(content, TextPromptMessageContent): message_contents.append(content.data) elif isinstance(content, ImagePromptMessageContent): - message_contents.append({ - 'type': 'image', - 'data': content.data, - 'detail': content.detail.value - }) + message_contents.append( + { + "type": "image", + "data": content.data, + "detail": content.detail.value, + } + ) messages.append(HumanMessage(content=message_contents)) elif isinstance(prompt_message, AssistantPromptMessage): - message_kwargs = { - 'content': prompt_message.content - } + message_kwargs = {"content": prompt_message.content} if prompt_message.tool_calls: - message_kwargs['additional_kwargs'] = { - 'function_call': { - 'id': prompt_message.tool_calls[0].id, - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message_kwargs["additional_kwargs"] = { + "function_call": { + "id": prompt_message.tool_calls[0].id, + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } } @@ -130,6 +147,10 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list elif isinstance(prompt_message, SystemPromptMessage): messages.append(SystemMessage(content=prompt_message.content)) elif isinstance(prompt_message, ToolPromptMessage): - messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content)) + messages.append( + FunctionMessage( + name=prompt_message.tool_call_id, content=prompt_message.content + ) + ) return messages diff --git a/model-providers/model_providers/core/entities/model_entities.py b/model-providers/model_providers/core/entities/model_entities.py index 2ae3bcd9..20e5dbc9 100644 --- a/model-providers/model_providers/core/entities/model_entities.py +++ b/model-providers/model_providers/core/entities/model_entities.py @@ -4,7 +4,10 @@ from typing import Optional from pydantic import BaseModel from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import ModelType, ProviderModel +from model_providers.core.model_runtime.entities.model_entities import ( + ModelType, + ProviderModel, +) from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity @@ -12,6 +15,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -22,6 +26,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. 
""" + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -39,7 +44,7 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -47,6 +52,7 @@ class ModelWithProviderEntity(ProviderModel): """ Model with provider entity. """ + provider: SimpleModelProviderEntity status: ModelStatus @@ -55,6 +61,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -66,6 +73,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/model-providers/model_providers/core/entities/provider_configuration.py b/model-providers/model_providers/core/entities/provider_configuration.py index 0b05635c..947a2900 100644 --- a/model-providers/model_providers/core/entities/provider_configuration.py +++ b/model-providers/model_providers/core/entities/provider_configuration.py @@ -7,9 +7,16 @@ from typing import Optional from pydantic import BaseModel -from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from model_providers.core.entities.model_entities import ( + ModelStatus, + ModelWithProviderEntity, + SimpleModelProviderEntity, +) from model_providers.core.entities.provider_entities import CustomConfiguration -from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType +from model_providers.core.model_runtime.entities.model_entities import ( + FetchFrom, + ModelType, +) from model_providers.core.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, @@ -18,7 +25,9 @@ from model_providers.core.model_runtime.entities.provider_entities import ( ) from model_providers.core.model_runtime.model_providers import model_provider_factory from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -27,13 +36,16 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. """ + provider: ProviderEntity custom_configuration: CustomConfiguration def __init__(self, **data): super().__init__(**data) - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: + def get_current_credentials( + self, model_type: ModelType, model: str + ) -> Optional[dict]: """ Get current credentials. 
@@ -43,7 +55,10 @@ class ProviderConfiguration(BaseModel): """ if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: + if ( + model_configuration.model_type == model_type + and model_configuration.model == model + ): return model_configuration.credentials if self.custom_configuration.provider: @@ -69,8 +84,9 @@ class ProviderConfiguration(BaseModel): copy_credentials = credentials.copy() return copy_credentials - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -83,7 +99,10 @@ class ProviderConfiguration(BaseModel): return None for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: + if ( + model_configuration.model_type == model_type + and model_configuration.model == model + ): credentials = model_configuration.credentials if not obfuscated: return credentials @@ -113,9 +132,9 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return provider_instance.get_model_instance(model_type) - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -131,8 +150,9 @@ class ProviderConfiguration(BaseModel): return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type @@ -148,18 +168,19 @@ class ProviderConfiguration(BaseModel): model_types = provider_instance.get_provider_schema().supported_model_types provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance + model_types=model_types, provider_instance=provider_instance ) if only_active: - provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] + provider_models = [ + m for m in provider_models if m.status == ModelStatus.ACTIVE + ] # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, model_types: list[ModelType], provider_instance: ModelProvider + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. 
@@ -189,7 +210,9 @@ class ProviderConfiguration(BaseModel): model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + status=ModelStatus.ACTIVE + if credentials + else ModelStatus.NO_CONFIGURE, ) ) @@ -199,15 +222,13 @@ class ProviderConfiguration(BaseModel): continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -223,7 +244,7 @@ class ProviderConfiguration(BaseModel): model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE + status=ModelStatus.ACTIVE, ) ) @@ -234,16 +255,18 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + configurations: dict[str, ProviderConfiguration] = {} def __init__(self): super().__init__() - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, + provider: Optional[str] = None, + model_type: Optional[ModelType] = None, + only_active: bool = False, + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -278,7 +301,9 @@ class ProviderConfigurations(BaseModel): if provider and provider_configuration.provider.provider != provider: continue - all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) + all_models.extend( + provider_configuration.get_provider_models(model_type, only_active) + ) return all_models @@ -310,6 +335,7 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. """ + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel diff --git a/model-providers/model_providers/core/entities/provider_entities.py b/model-providers/model_providers/core/entities/provider_entities.py index 715cf899..0f6ebd49 100644 --- a/model-providers/model_providers/core/entities/provider_entities.py +++ b/model-providers/model_providers/core/entities/provider_entities.py @@ -12,11 +12,11 @@ class RestrictModel(BaseModel): model_type: ModelType - class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -24,6 +24,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -33,5 +34,6 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. 
""" + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] diff --git a/model-providers/model_providers/core/entities/queue_entities.py b/model-providers/model_providers/core/entities/queue_entities.py index cd5e8267..7ba21aa6 100644 --- a/model-providers/model_providers/core/entities/queue_entities.py +++ b/model-providers/model_providers/core/entities/queue_entities.py @@ -3,13 +3,17 @@ from typing import Any from pydantic import BaseModel -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, +) class QueueEvent(Enum): """ QueueEvent enum """ + MESSAGE = "message" AGENT_MESSAGE = "agent_message" MESSAGE_REPLACE = "message-replace" @@ -27,6 +31,7 @@ class AppQueueEvent(BaseModel): """ QueueEvent entity """ + event: QueueEvent @@ -34,21 +39,25 @@ class QueueMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event = QueueEvent.MESSAGE chunk: LLMResultChunk + class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event = QueueEvent.MESSAGE_REPLACE text: str @@ -57,6 +66,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ QueueRetrieverResourcesEvent entity """ + event = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] @@ -65,6 +75,7 @@ class AnnotationReplyEvent(AppQueueEvent): """ AnnotationReplyEvent entity """ + event = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -73,28 +84,34 @@ class QueueMessageEndEvent(AppQueueEvent): """ QueueMessageEndEvent entity """ + event = QueueEvent.MESSAGE_END llm_result: LLMResult - + class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event = QueueEvent.AGENT_THOUGHT agent_thought_id: str + class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event = QueueEvent.MESSAGE_FILE message_file_id: str - + + class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity """ + event = QueueEvent.ERROR error: Any @@ -103,6 +120,7 @@ class QueuePingEvent(AppQueueEvent): """ QueuePingEvent entity """ + event = QueueEvent.PING @@ -110,10 +128,12 @@ class QueueStopEvent(AppQueueEvent): """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -126,6 +146,7 @@ class QueueMessage(BaseModel): """ QueueMessage entity """ + task_id: str message_id: str conversation_id: str diff --git a/model-providers/model_providers/core/model_manager.py b/model-providers/model_providers/core/model_manager.py index f13a7b6e..5e2c69bd 100644 --- a/model-providers/model_providers/core/model_manager.py +++ b/model-providers/model_providers/core/model_manager.py @@ -2,23 +2,40 @@ from collections.abc import Generator from typing import IO, Optional, Union, cast from model_providers.core.entities.provider_configuration import ProviderModelBundle -from model_providers.errors.error import ProviderTokenNotInitError from model_providers.core.model_runtime.callbacks.base_callback import Callback from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from 
model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.rerank_entities import RerankResult -from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel -from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.__base.moderation_model import ( + ModerationModel, +) +from model_providers.core.model_runtime.model_providers.__base.rerank_model import ( + RerankModel, +) +from model_providers.core.model_runtime.model_providers.__base.speech2text_model import ( + Speech2TextModel, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel from model_providers.core.provider_manager import ProviderManager +from model_providers.errors.error import ProviderTokenNotInitError -def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: +def _fetch_credentials_from_bundle( + provider_model_bundle: ProviderModelBundle, model: str +) -> dict: """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -26,12 +43,13 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m :return: """ credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=provider_model_bundle.model_type_instance.model_type, - model=model + model_type=provider_model_bundle.model_type_instance.model_type, model=model ) if credentials is None: - raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") + raise ProviderTokenNotInitError( + f"Model {model} credentials is not initialized." 
+ ) return credentials @@ -48,10 +66,16 @@ class ModelInstance: self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model) self.model_type_instance = self._provider_model_bundle.model_type_instance - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ - -> Union[LLMResult, Generator]: + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,11 +101,12 @@ class ModelInstance: stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke_text_embedding( + self, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: """ Invoke large language model @@ -94,16 +119,17 @@ class ModelInstance: self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - texts=texts, - user=user + model=self.model, credentials=self.credentials, texts=texts, user=user ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke_rerank( + self, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -125,11 +151,10 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user + user=user, ) - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: + def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -142,14 +167,10 @@ class ModelInstance: self.model_type_instance = cast(ModerationModel, self.model_type_instance) return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - text=text, - user=user + model=self.model, credentials=self.credentials, text=text, user=user ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -162,14 +183,17 @@ class ModelInstance: self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) return self.model_type_instance.invoke( - model=self.model, - credentials=self.credentials, - file=file, - user=user + model=self.model, credentials=self.credentials, file=file, user=user ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ - -> str: + def invoke_tts( + self, + content_text: str, + tenant_id: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ) -> str: """ Invoke large language tts model @@ -191,7 +215,7 @@ class ModelInstance: user=user, tenant_id=tenant_id, voice=voice, - streaming=streaming + streaming=streaming, ) def 
get_tts_voices(self, language: str) -> list: @@ -206,21 +230,24 @@ class ModelInstance: self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language + model=self.model, credentials=self.credentials, language=language ) class ModelManager: - def __init__(self, - provider_name_to_provider_records_dict: dict, - provider_name_to_provider_model_records_dict: dict) -> None: + def __init__( + self, + provider_name_to_provider_records_dict: dict, + provider_name_to_provider_model_records_dict: dict, + ) -> None: self._provider_manager = ProviderManager( provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, - provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict) + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) - def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance: + def get_model_instance( + self, provider: str, model_type: ModelType, model: str + ) -> ModelInstance: """ Get model instance :param provider: provider name @@ -231,8 +258,7 @@ class ModelManager: if not provider: return self.get_default_model_instance(model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( - provider=provider, - model_type=model_type + provider=provider, model_type=model_type ) return ModelInstance(provider_model_bundle, model) @@ -253,5 +279,5 @@ class ModelManager: return self.get_model_instance( provider=default_model_entity.provider.provider, model_type=model_type, - model=default_model_entity.model + model=default_model_entity.model, ) diff --git a/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py b/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py index a5103ab3..f7b8c3e5 100644 --- a/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py +++ b/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py @@ -1,8 +1,14 @@ from abc import ABC from typing import Optional -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, +) +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { @@ -19,12 +25,21 @@ class Callback(ABC): Base class for callbacks. Only for LLM. 
""" + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -40,10 +55,19 @@ class Callback(ABC): """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -60,10 +84,19 @@ class Callback(ABC): """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -80,10 +113,19 @@ class Callback(ABC): """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -100,9 +142,7 @@ class Callback(ABC): """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py b/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py index 1ee0b740..be78e354 100644 --- 
a/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py +++ b/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py @@ -4,17 +4,32 @@ import sys from typing import Optional from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, +) +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +43,49 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = 
None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +103,19 @@ class LoggingCallback(Callback): sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +130,37 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text( + f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow" + ) - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text( + f"System Fingerprint: {result.system_fingerprint}\n", color="yellow" + ) - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +175,5 @@ class LoggingCallback(Callback): :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") 
logger.exception(ex) diff --git a/model-providers/model_providers/core/model_runtime/entities/common_entities.py b/model-providers/model_providers/core/model_runtime/entities/common_entities.py index 175c13cf..659ad59b 100644 --- a/model-providers/model_providers/core/model_runtime/entities/common_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/model-providers/model_providers/core/model_runtime/entities/defaults.py b/model-providers/model_providers/core/model_runtime/entities/defaults.py index 438aaa3d..98719aac 100644 --- a/model-providers/model_providers/core/model_runtime/entities/defaults.py +++ b/model-providers/model_providers/core/model_runtime/entities/defaults.py @@ -1,98 +1,99 @@ - -from model_providers.core.model_runtime.entities.model_entities import DefaultParameterName +from model_providers.core.model_runtime.entities.model_entities import ( + DefaultParameterName, +) PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. 
Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", }, - 'type': 'int', - 'help': { - 'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.', - 'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。', + "type": "int", + "help": { + "en_US": "The maximum number of tokens to generate. 
Requests can use up to 2048 tokens shared between prompt and completion.", + "zh_Hans": "要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。", }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], - } + "required": False, + "options": ["JSON", "XML"], + }, } diff --git a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py index 99f37500..eafdfb2b 100644 --- a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py @@ -4,19 +4,26 @@ from typing import Optional from pydantic import BaseModel -from model_providers.core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from model_providers.core.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from model_providers.core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, +) +from model_providers.core.model_runtime.entities.model_entities import ( + ModelUsage, + PriceInfo, +) class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +33,14 @@ class LLMMode(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. """ + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,17 +58,17 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) @@ -68,6 +76,7 @@ class LLMResult(BaseModel): """ Model class for llm result. """ + model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -79,6 +88,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. 
""" + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -89,6 +99,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -99,4 +110,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/model-providers/model_providers/core/model_runtime/entities/message_entities.py b/model-providers/model_providers/core/model_runtime/entities/message_entities.py index 83b12082..c9a823c0 100644 --- a/model-providers/model_providers/core/model_runtime/entities/message_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/message_entities.py @@ -9,13 +9,14 @@ class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. @@ -25,13 +26,14 @@ class PromptMessageRole(Enum): for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,16 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT @@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. """ + class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -93,6 +101,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -100,14 +109,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,6 +135,7 @@ class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -130,5 +143,6 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. 
""" + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index 50b822d5..307d459a 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. """ + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,28 @@ class ModelType(Enum): :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif ( + origin_model_type == "embeddings" + or origin_model_type == cls.TEXT_EMBEDDING.value + ): return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif ( + origin_model_type == "speech2text" + or origin_model_type == cls.SPEECH2TEXT.value + ): return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +57,28 @@ class ModelType(Enum): :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +87,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -89,6 +99,7 @@ class DefaultParameterName(Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" PRESENCE_PENALTY = "presence_penalty" @@ -97,7 +108,7 @@ class DefaultParameterName(Enum): RESPONSE_FORMAT = "response_format" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. 
@@ -107,13 +118,14 @@ class DefaultParameterName(Enum): for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -124,6 +136,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. """ + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -141,6 +154,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -157,6 +171,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -174,6 +189,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -184,6 +200,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -196,6 +213,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -204,6 +222,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py index e71ca0f1..21b610ad 100644 --- a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py @@ -4,13 +4,18 @@ from typing import Optional from pydantic import BaseModel from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, + ProviderModel, +) class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -19,6 +24,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -30,6 +36,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -38,6 +45,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -45,15 +53,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -69,6 +76,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -81,6 +89,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -89,6 +98,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. 
""" + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -101,6 +111,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -109,6 +120,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -137,7 +149,7 @@ class ProviderEntity(BaseModel): icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -145,5 +157,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. """ + provider: str credentials: dict diff --git a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py index d51efd2b..99709e1b 100644 --- a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py index a67c63b6..fa2172a0 100644 --- a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/model-providers/model_providers/core/model_runtime/errors/invoke.py b/model-providers/model_providers/core/model_runtime/errors/invoke.py index 0513cfaf..edfb19c7 100644 --- a/model-providers/model_providers/core/model_runtime/errors/invoke.py +++ b/model-providers/model_providers/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ from typing import Optional class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ class InvokeError(Exception): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. 
" class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/model-providers/model_providers/core/model_runtime/errors/validate.py b/model-providers/model_providers/core/model_runtime/errors/validate.py index 8db79a52..7fcd2133 100644 --- a/model-providers/model_providers/core/model_runtime/errors/validate.py +++ b/model-providers/model_providers/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py index e26686bb..2c3233d4 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py @@ -16,15 +16,24 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from model_providers.core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from model_providers.core.utils.position_helper import get_position_map, sort_by_position_map +from model_providers.core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeError, +) +from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import ( + GPT2Tokenizer, +) +from model_providers.core.utils.position_helper import ( + get_position_map, + sort_by_position_map, +) class AIModel(ABC): """ Base class for all models. """ + model_type: ModelType model_schemas: list[AIModelEntity] = None started_at: float = 0 @@ -60,18 +69,24 @@ class AIModel(ABC): :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ") + return invoke_error( + description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. 
" + ) - return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") + return invoke_error( + description=f"[{provider_name}] {invoke_error.description}, {str(error)}" + ) return InvokeError(description=f"[{provider_name}] Error: {str(error)}") - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: + def get_price( + self, model: str, credentials: dict, price_type: PriceType, tokens: int + ) -> PriceInfo: """ Get price for given model and tokens @@ -99,15 +114,17 @@ class AIModel(ABC): if unit_price is None: return PriceInfo( - unit_price=decimal.Decimal('0.0'), - unit=decimal.Decimal('0.0'), - total_amount=decimal.Decimal('0.0'), + unit_price=decimal.Decimal("0.0"), + unit=decimal.Decimal("0.0"), + total_amount=decimal.Decimal("0.0"), currency="USD", ) # calculate total amount total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize( + decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP + ) return PriceInfo( unit_price=unit_price, @@ -128,24 +145,28 @@ class AIModel(ABC): model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") + and os.path.isfile( + os.path.join(provider_model_type_path, model_schema_yaml) + ) + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -154,59 +175,73 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, encoding='utf-8') as f: + with open(model_schema_yaml_path, encoding="utf-8") as f: yaml_data = yaml.safe_load(f) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) + default_parameter_name = DefaultParameterName.value_of( + parameter_rule["use_template"] + ) + default_parameter_rule = ( + self._get_default_parameter_rule_variable_map( + default_parameter_name + ) + ) copy_default_parameter_rule = default_parameter_rule.copy() 
copy_default_parameter_rule.update(parameter_rule) parameter_rule = copy_default_parameter_rule except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] + if "label" not in parameter_rule: + parameter_rule["label"] = { + "zh_Hans": parameter_rule["name"], + "en_US": parameter_rule["name"], } new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] + if "label" not in yaml_data: + yaml_data["label"] = { + "zh_Hans": yaml_data["model"], + "en_US": yaml_data["model"], } - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = AIModelEntity(**yaml_data) except Exception as e: - model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + model_schema_yaml_file_name = os.path.basename( + model_schema_yaml_path + ).rstrip(".yaml") + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:" + f" {str(e)}" + ) # cache model schema model_schemas.append(model_schema) # resort model schemas by position - model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model) + model_schemas = sort_by_position_map( + position_map, model_schemas, lambda x: x.model + ) # cache model schemas self.model_schemas = model_schemas return model_schemas - def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: + def get_model_schema( + self, model: str, credentials: Optional[dict] = None + ) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -222,13 +257,17 @@ class AIModel(ABC): return model_map[model] if credentials: - model_schema = self.get_customizable_model_schema_from_credentials(model, credentials) + model_schema = self.get_customizable_model_schema_from_credentials( + model, credentials + ) if model_schema: return model_schema return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -238,7 +277,9 @@ class AIModel(ABC): """ return self._get_customizable_model_schema(model, credentials) - def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def _get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -252,26 +293,51 @@ class AIModel(ABC): for parameter_rule in schema.parameter_rules: if parameter_rule.use_template: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = 
default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: - parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + default_parameter_name = DefaultParameterName.value_of( + parameter_rule.use_template + ) + default_parameter_rule = ( + self._get_default_parameter_rule_variable_map( + default_parameter_name ) - if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): - parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] - if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): - parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) + ) + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if ( + not parameter_rule.default + and "default" in default_parameter_rule + ): + parameter_rule.default = default_parameter_rule["default"] + if ( + not parameter_rule.precision + and "precision" in default_parameter_rule + ): + parameter_rule.precision = default_parameter_rule["precision"] + if ( + not parameter_rule.required + and "required" in default_parameter_rule + ): + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: + parameter_rule.help = I18nObject( + en_US=default_parameter_rule["help"]["en_US"], + ) + if not parameter_rule.help.en_US and ( + "help" in default_parameter_rule + and "en_US" in default_parameter_rule["help"] + ): + parameter_rule.help.en_US = default_parameter_rule["help"][ + "en_US" + ] + if not parameter_rule.help.zh_Hans and ( + "help" in default_parameter_rule + and "zh_Hans" in default_parameter_rule["help"] + ): + parameter_rule.help.zh_Hans = default_parameter_rule[ + "help" + ].get("zh_Hans", default_parameter_rule["help"]["en_US"]) except ValueError: pass @@ -281,7 +347,9 @@ class AIModel(ABC): return schema - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: """ Get customizable model schema @@ -291,7 +359,9 @@ class AIModel(ABC): """ return None - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: + def _get_default_parameter_rule_variable_map( + self, name: DefaultParameterName + ) -> dict: """ Get default parameter rule for given name @@ -301,7 +371,7 @@ class AIModel(ABC): default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) if not default_parameter_rule: - raise Exception(f'Invalid model parameter rule name {name}') + raise Exception(f"Invalid model parameter rule name {name}") return default_parameter_rule diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py 
b/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py index 72b72f1b..2b0a1e20 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py @@ -7,8 +7,16 @@ from collections.abc import Generator from typing import Optional, Union from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.callbacks.logging_callback import LoggingCallback -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from model_providers.core.model_runtime.callbacks.logging_callback import ( + LoggingCallback, +) +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -32,13 +40,21 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. """ + model_type: ModelType = ModelType.LLM - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -57,7 +73,9 @@ class LargeLanguageModel(AIModel): if model_parameters is None: model_parameters = {} - model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials) + model_parameters = self._validate_and_filter_model_parameters( + model, model_parameters, credentials + ) self.started_at = time.perf_counter() @@ -76,7 +94,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: @@ -90,10 +108,19 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: - result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + result = self._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) except Exception as e: self._trigger_invoke_error_callbacks( model=model, @@ -105,7 +132,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -121,7 +148,7 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: self._trigger_after_invoke_callbacks( @@ -134,15 +161,23 @@ class LargeLanguageModel(AIModel): stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: 
dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -177,36 +212,44 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) block_prompts = block_prompts.replace("{{block}}", code_block) # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[0], SystemPromptMessage + ): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) + content=block_prompts.replace( + "{{instructions}}", prompt_messages[0].content + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace( + "{{instructions}}", + f"Please output a valid {code_block} object.", + ) + ), + ) - if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[-1], UserPromptMessage + ): # add ```JSON\n to the last message prompt_messages[-1].content += f"\n```{code_block}\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -216,33 +259,40 @@ if you are not sure about the structure. 
tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response - if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): + if ( + first_chunk.delta.message.content + and first_chunk.delta.message.content.startswith("`") + ): return self._code_block_mode_stream_processor_with_backtick( model=model, prompt_messages=prompt_messages, - input_generator=new_generator() + input_generator=new_generator(), ) else: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=new_generator() + input_generator=new_generator(), ) - + return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, + model: str, + prompt_messages: list[PromptMessage], + input_generator: Generator[LLMResultChunk, None, None], + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -291,15 +341,17 @@ if you are not sure about the structure. delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] + content=new_piece, tool_calls=[] ), - ) + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, + model: str, + prompt_messages: list, + input_generator: Generator[LLMResultChunk, None, None], + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. @@ -366,26 +418,31 @@ if you are not sure about the structure. delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] + content=new_piece, tool_calls=[] ), - ) + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Generator: """ Invoke result generator :param result: result generator :return: result generator """ - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -404,7 +461,7 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -424,7 +481,7 @@ if you are not sure about the structure. 
prompt_messages=prompt_messages, message=prompt_message, usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -433,15 +490,21 @@ if you are not sure about the structure. stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -456,10 +519,15 @@ if you are not sure about the structure. :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -489,7 +557,9 @@ if you are not sure about the structure. for word in result.message.content: assistant_prompt_message = AssistantPromptMessage( content=word, - tool_calls=tool_calls if index == (len(result.message.content) - 1) else [] + tool_calls=tool_calls + if index == (len(result.message.content) - 1) + else [], ) yield LLMResultChunk( @@ -499,7 +569,7 @@ if you are not sure about the structure. delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -531,11 +601,15 @@ if you are not sure about the structure. mode = LLMMode.CHAT if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): - mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]) + mode = LLMMode.value_of( + model_schema.model_properties[ModelPropertyKey.MODE] + ) return mode - def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -558,7 +632,7 @@ if you are not sure about the structure. model=model, credentials=credentials, price_type=PriceType.OUTPUT, - tokens=completion_tokens + tokens=completion_tokens, ) # transform usage @@ -572,18 +646,26 @@ if you are not sure about the structure. 
completion_price_unit=completion_price_info.unit, completion_price=completion_price_info.total_amount, total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, + total_price=prompt_price_info.total_amount + + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> None: """ Trigger before invoke callbacks @@ -609,19 +691,29 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") + logger.warning( + f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}" + ) - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> None: """ Trigger new chunk callbacks @@ -648,19 +740,29 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") + logger.warning( + f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}" + ) - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> None: """ Trigger after invoke callbacks @@ -688,19 +790,29 @@ if you are not sure about the structure. 
tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") + logger.warning( + f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}" + ) - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: list[Callback] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> None: """ Trigger invoke error callbacks @@ -728,15 +840,19 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}") + logger.warning( + f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}" + ) - def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict: + def _validate_and_filter_model_parameters( + self, model: str, model_parameters: dict, credentials: dict + ) -> dict: """ Validate model parameters @@ -753,16 +869,23 @@ if you are not sure about the structure. parameter_name = parameter_rule.name parameter_value = model_parameters.get(parameter_name) if parameter_value is None: - if parameter_rule.use_template and parameter_rule.use_template in model_parameters: + if ( + parameter_rule.use_template + and parameter_rule.use_template in model_parameters + ): # if parameter value is None, use template value variable name instead parameter_value = model_parameters[parameter_rule.use_template] else: if parameter_rule.required: if parameter_rule.default is not None: - filtered_model_parameters[parameter_name] = parameter_rule.default + filtered_model_parameters[ + parameter_name + ] = parameter_rule.default continue else: - raise ValueError(f"Model Parameter {parameter_name} is required.") + raise ValueError( + f"Model Parameter {parameter_name} is required." + ) else: continue @@ -772,47 +895,81 @@ if you are not sure about the structure. raise ValueError(f"Model Parameter {parameter_name} should be int.") # validate parameter value range - if parameter_rule.min is not None and parameter_value < parameter_rule.min: + if ( + parameter_rule.min is not None + and parameter_value < parameter_rule.min + ): raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) - if parameter_rule.max is not None and parameter_value > parameter_rule.max: + if ( + parameter_rule.max is not None + and parameter_value > parameter_rule.max + ): raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." 
+ ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): - raise ValueError(f"Model Parameter {parameter_name} should be float.") + raise ValueError( + f"Model Parameter {parameter_name} should be float." + ) # validate parameter value precision if parameter_rule.precision is not None: if parameter_rule.precision == 0: if parameter_value != int(parameter_value): - raise ValueError(f"Model Parameter {parameter_name} should be int.") - else: - if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be int." + ) + else: + if parameter_value != round( + parameter_value, parameter_rule.precision + ): + raise ValueError( + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places." + ) # validate parameter value range - if parameter_rule.min is not None and parameter_value < parameter_rule.min: + if ( + parameter_rule.min is not None + and parameter_value < parameter_rule.min + ): raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) - if parameter_rule.max is not None and parameter_value > parameter_rule.max: + if ( + parameter_rule.max is not None + and parameter_value > parameter_rule.max + ): raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): - raise ValueError(f"Model Parameter {parameter_name} should be bool.") + raise ValueError( + f"Model Parameter {parameter_name} should be bool." + ) elif parameter_rule.type == ParameterType.STRING: if not isinstance(parameter_value, str): - raise ValueError(f"Model Parameter {parameter_name} should be string.") + raise ValueError( + f"Model Parameter {parameter_name} should be string." + ) # validate options - if parameter_rule.options and parameter_value not in parameter_rule.options: - raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") + if ( + parameter_rule.options + and parameter_value not in parameter_rule.options + ): + raise ValueError( + f"Model Parameter {parameter_name} should be one of {parameter_rule.options}." + ) else: - raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.") + raise ValueError( + f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported." 
+ ) filtered_model_parameters[parameter_name] = parameter_value diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py index 9814ac06..48f4a942 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py @@ -4,7 +4,10 @@ from abc import ABC, abstractmethod import yaml -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, +) from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel @@ -36,24 +39,26 @@ class ModelProvider(ABC): return self.provider_schema # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) - current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) + current_path = os.path.join( + os.path.dirname(os.path.dirname(base_path)), provider_name + ) # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = {} if os.path.exists(yaml_path): - with open(yaml_path, encoding='utf-8') as f: + with open(yaml_path, encoding="utf-8") as f: yaml_data = yaml.safe_load(f) try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema @@ -88,37 +93,52 @@ class ModelProvider(ABC): :return: """ # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] if f"{provider_name}.{model_type.value}" in self.model_instance_map: return self.model_instance_map[f"{provider_name}.{model_type.value}"] # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') - model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_name = model_type.value.replace("-", "_") + model_type_path = os.path.join( + os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name + ) + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception( + f"Invalid model type {model_type} for provider {provider_name}" + ) # Dynamic loading {model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) - spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path) + parent_module = 
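# --- Illustrative sketch (not part of the diff above): how a rule-based check like
# _validate_and_filter_model_parameters behaves for a single integer rule. The MiniRule
# dataclass and check_int_rule helper are hypothetical stand-ins for the project's
# ParameterRule handling, shown only to clarify the required/default/min/max logic.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniRule:
    name: str
    required: bool = False
    default: Optional[int] = None
    min: Optional[int] = None
    max: Optional[int] = None


def check_int_rule(rule: MiniRule, params: dict) -> dict:
    value = params.get(rule.name)
    if value is None:
        if rule.required:
            if rule.default is not None:
                return {rule.name: rule.default}  # fall back to the declared default
            raise ValueError(f"Model Parameter {rule.name} is required.")
        return {}  # optional and absent: filtered out entirely
    if not isinstance(value, int):
        raise ValueError(f"Model Parameter {rule.name} should be int.")
    if rule.min is not None and value < rule.min:
        raise ValueError(f"Model Parameter {rule.name} should be >= {rule.min}.")
    if rule.max is not None and value > rule.max:
        raise ValueError(f"Model Parameter {rule.name} should be <= {rule.max}.")
    return {rule.name: value}


# Example: an absent required parameter falls back to its default value.
print(check_int_rule(MiniRule("max_tokens", required=True, default=512, min=1, max=4096), {}))
# -> {'max_tokens': 512}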
".".join(self.__class__.__module__.split(".")[:-1]) + spec = importlib.util.spec_from_file_location( + f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path + ) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) model_class = None for name, obj in vars(mod).items(): - if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__ - and obj != AIModel and obj.__module__ == mod.__name__): + if ( + isinstance(obj, type) + and issubclass(obj, AIModel) + and not obj.__abstractmethods__ + and obj != AIModel + and obj.__module__ == mod.__name__ + ): model_class = obj break if not model_class: - raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') + raise Exception( + f"Missing AIModel Class for model type {model_type} in {model_type_py_path}" + ) model_instance_map = model_class() - self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map + self.model_instance_map[ + f"{provider_name}.{model_type.value}" + ] = model_instance_map return model_instance_map diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/moderation_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/moderation_model.py index 21cc7e1c..e6b2ca72 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/moderation_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/moderation_model.py @@ -10,11 +10,12 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke( + self, model: str, credentials: dict, text: str, user: Optional[str] = None + ) -> bool: """ Invoke moderation model @@ -32,9 +33,9 @@ class ModerationModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke( + self, model: str, credentials: dict, text: str, user: Optional[str] = None + ) -> bool: """ Invoke large language model @@ -45,4 +46,3 @@ class ModerationModel(AIModel): :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py index e38cc837..6fdafde9 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. 
""" + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -32,15 +39,23 @@ class RerankModel(AIModel): self.started_at = time.perf_counter() try: - return self._invoke(model, credentials, query, docs, score_threshold, top_n, user) + return self._invoke( + model, credentials, query, docs, score_threshold, top_n, user + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/speech2text_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/speech2text_model.py index eaed8282..82a28943 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/speech2text_model.py @@ -10,11 +10,12 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. """ + model_type: ModelType = ModelType.SPEECH2TEXT - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: """ Invoke large language model @@ -30,9 +31,9 @@ class Speech2TextModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: """ Invoke large language model @@ -54,4 +55,4 @@ class Speech2TextModel(AIModel): current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py index 4139cfc1..058c910e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py @@ -9,11 +9,17 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. 
""" + model_type: ModelType = ModelType.TEXT2IMG - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, + model: str, + credentials: dict, + prompt: str, + model_parameters: dict, + user: Optional[str] = None, + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -31,9 +37,14 @@ class Text2ImageModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, + model: str, + credentials: dict, + prompt: str, + model_parameters: dict, + user: Optional[str] = None, + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py index 20ce474b..d46b412e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,8 +2,13 @@ import time from abc import abstractmethod from typing import Optional -from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from model_providers.core.model_runtime.entities.model_entities import ( + ModelPropertyKey, + ModelType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel @@ -11,11 +16,16 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. 
""" + model_type: ModelType = ModelType.TEXT_EMBEDDING - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke large language model @@ -33,9 +43,13 @@ class TextEmbeddingModel(AIModel): raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke large language model @@ -69,7 +83,10 @@ class TextEmbeddingModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] return 1000 @@ -84,7 +101,10 @@ class TextEmbeddingModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] return 1 diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f5..f48d9a9d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,30 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") + _tokenizer = TransformerGPT2Tokenizer.from_pretrained( + gpt2_tokenizer_path + ) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/tts_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/tts_model.py index b99f3c79..47ebbd56 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__base/tts_model.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/tts_model.py @@ -4,7 +4,10 @@ import uuid from abc import abstractmethod from typing import Optional -from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from 
model_providers.core.model_runtime.entities.model_entities import ( + ModelPropertyKey, + ModelType, +) from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,10 +16,19 @@ class TTSModel(AIModel): """ Model class for ttstext model. """ + model_type: ModelType = ModelType.TTS - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, - user: Optional[str] = None): + def invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ): """ Invoke large language model @@ -31,14 +43,29 @@ class TTSModel(AIModel): """ try: self._is_ffmpeg_installed() - return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + streaming=streaming, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, - user: Optional[str] = None): + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ): """ Invoke large language model @@ -53,7 +80,9 @@ class TTSModel(AIModel): """ raise NotImplementedError - def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + def get_tts_model_voices( + self, model: str, credentials: dict, language: Optional[str] = None + ) -> list: """ Get voice for given tts model voices @@ -67,9 +96,13 @@ class TTSModel(AIModel): if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: voices = model_schema.model_properties[ModelPropertyKey.VOICES] if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if language and language in d.get('language')] + return [ + {"name": d["name"], "value": d["mode"]} + for d in voices + if language and language in d.get("language") + ] else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + return [{"name": d["name"], "value": d["mode"]} for d in voices] def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ @@ -81,7 +114,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE] def _get_model_audio_type(self, model: str, credentials: dict) -> str: @@ -94,7 +130,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] def _get_model_word_limit(self, model: str, credentials: dict) -> int: @@ -104,7 +143,10 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and 
ModelPropertyKey.WORD_LIMIT in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] def _get_model_workers_limit(self, model: str, credentials: dict) -> int: @@ -114,13 +156,16 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod def _split_text_into_sentences(text: str, limit: int, delimiters=None): if delimiters is None: - delimiters = set('。!?;\n') + delimiters = set("。!?;\n") buf = [] word_count = 0 @@ -128,7 +173,7 @@ class TTSModel(AIModel): buf.append(char) if char in delimiters: if word_count >= limit: - yield ''.join(buf) + yield "".join(buf) buf = [] word_count = 0 else: @@ -137,7 +182,7 @@ class TTSModel(AIModel): word_count += 1 if buf: - yield ''.join(buf) + yield "".join(buf) @staticmethod def _is_ffmpeg_installed(): @@ -146,13 +191,17 @@ class TTSModel(AIModel): if "ffmpeg version" in output.decode("utf-8"): return True else: - raise InvokeBadRequestError("ffmpeg is not installed, " - "details: https://docs.dify.ai/getting-started/install-self-hosted" - "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech") + raise InvokeBadRequestError( + "ffmpeg is not installed, " + "details: https://docs.dify.ai/getting-started/install-self-hosted" + "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech" + ) except Exception: - raise InvokeBadRequestError("ffmpeg is not installed, " - "details: https://docs.dify.ai/getting-started/install-self-hosted" - "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech") + raise InvokeBadRequestError( + "ffmpeg is not installed, " + "details: https://docs.dify.ai/getting-started/install-self-hosted" + "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech" + ) # Todo: To improve the streaming function @staticmethod @@ -160,6 +209,6 @@ class TTSModel(AIModel): hash_object = hashlib.sha256(file_content.encode()) hex_digest = hash_object.hexdigest() - namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31') + namespace_uuid = uuid.UUID("a5da6ef9-b303-596f-8e88-bf8fa40f4b31") unique_uuid = uuid.uuid5(namespace_uuid, hex_digest) return str(unique_uuid) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/__init__.py index f3578b0f..d3b363f1 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/__init__.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/__init__.py @@ -1,3 +1,5 @@ -from model_providers.core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from model_providers.core.model_runtime.model_providers.model_provider_factory import ( + ModelProviderFactory, +) model_provider_factory = ModelProviderFactory() diff --git a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/anthropic.py b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/anthropic.py index 3f5b9507..06b7c611 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/anthropic.py +++ 
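# --- Illustrative sketch (not part of the diff above): what two of the TTSModel helpers
# above do, restated as standalone functions. split_sentences mirrors
# _split_text_into_sentences (accumulate characters, flush at a delimiter once the length
# limit is reached); stable_name mirrors the sha256 + uuid5 scheme used to derive a
# repeatable file name. The namespace UUID below is an arbitrary example value, not the
# one used by the project.
import hashlib
import uuid
from typing import Iterator


def split_sentences(text: str, limit: int, delimiters: frozenset = frozenset("。!?;\n")) -> Iterator[str]:
    buf, count = [], 0
    for ch in text:
        buf.append(ch)
        if ch in delimiters and count >= limit:
            yield "".join(buf)
            buf, count = [], 0
        elif ch not in delimiters:
            count += 1
    if buf:
        yield "".join(buf)


def stable_name(content: str) -> str:
    digest = hashlib.sha256(content.encode()).hexdigest()
    namespace = uuid.UUID("12345678-1234-5678-1234-567812345678")  # example namespace only
    return str(uuid.uuid5(namespace, digest))


# Example: the same text always splits the same way and maps to the same file name.
chunks = list(split_sentences("今天天气不错。我们出去走走吧!好的。", limit=5))
print(chunks, stable_name("".join(chunks)))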
b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/anthropic.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -21,11 +25,12 @@ class AnthropicProvider(ModelProvider): # Use `claude-instant-1` model for validate, model_instance.validate_credentials( - model='claude-instant-1.2', - credentials=credentials + model="claude-instant-1.2", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py index 6eda2ffd..032c2757 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -18,7 +18,11 @@ from anthropic.types import ( from httpx import Timeout from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -37,8 +41,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure @@ -51,11 +59,17 @@ if you are not sure about the structure. 
class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -70,11 +84,20 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + return self._chat_generate( + model, credentials, prompt_messages, model_parameters, stop, stream, user + ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -91,23 +114,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop( + "max_tokens_to_sample" + ) # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata( + user_id=user + ) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # chat model response = client.messages.create( @@ -115,22 +142,37 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): messages=prompt_message_dicts, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_chat_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages) + return self._handle_chat_generate_response( + model, credentials, response, prompt_messages + ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, 
Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format']: + if ( + "response_format" in model_parameters + and model_parameters["response_format"] + ): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -142,17 +184,33 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") - return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -162,25 +220,40 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): stop.append("\n```") # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[0], SystemPromptMessage + ): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", prompt_messages[0].content + ).replace("{{block}}", response_format) + ) + prompt_messages.append( + AssistantPromptMessage(content=f"\n```{response_format}") ) - prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) - prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", + f"Please output a valid {response_format} object.", + ).replace("{{block}}", response_format) + ), + ) + prompt_messages.append( + AssistantPromptMessage(content=f"\n```{response_format}") + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + 
model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -214,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Message, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -243,24 +321,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response response = LLMResult( model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, - usage=usage + usage=usage, ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -269,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -284,28 +370,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): finish_reason = chunk.delta.stop_reason elif isinstance(chunk, MessageStopEvent): # transform usage - usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) + usage = self._calc_response_usage( + model, credentials, input_tokens, output_tokens + ) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='' - ), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text if chunk.delta.text else "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -315,7 +399,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -326,18 +410,22 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: """ 
credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if 'anthropic_api_url' in credentials and credentials['anthropic_api_url']: - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if "anthropic_api_url" in credentials and credentials["anthropic_api_url"]: + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip( + "/" + ) + credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs - def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: + def _convert_prompt_messages( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[dict]]: """ Convert prompt messages to dict list and system """ @@ -348,7 +436,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if isinstance(message, SystemPromptMessage): system += message.content + ("\n" if not system else "") else: - prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) + prompt_message_dicts.append( + self._convert_prompt_message_to_dict(message) + ) return system, prompt_message_dicts @@ -364,38 +454,57 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) sub_message_dict = { "type": "text", - "text": message_content.data + "text": message_content.data, } sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + message_content = cast( + ImagePromptMessageContent, message_content + ) if not message_content.data.startswith("data:"): # fetch image data from url try: - image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) - base64_data = base64.b64encode(image_content).decode('utf-8') + image_content = requests.get( + message_content.data + ).content + mime_type, _ = mimetypes.guess_type( + message_content.data + ) + base64_data = base64.b64encode(image_content).decode( + "utf-8" + ) except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in [ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + ]: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", "source": { "type": "base64", "media_type": mime_type, - "data": base64_data - } + "data": base64_data, + }, } sub_messages.append(sub_message_dict) @@ -450,7 +559,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return message_text - def 
_convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str: + def _convert_messages_to_prompt_anthropic( + self, messages: list[PromptMessage] + ) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -458,15 +569,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) # trim off the trailing ' ' that might come from the "Assistant: " @@ -485,22 +595,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return { InvokeConnectionError: [ anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError + anthropic.APITimeoutError, ], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], InvokeAuthorizationError: [ anthropic.AuthenticationError, - anthropic.PermissionDeniedError + anthropic.PermissionDeniedError, ], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py index b480efc6..6ae57a15 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py @@ -9,16 +9,18 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION +from model_providers.core.model_runtime.model_providers.azure_openai._constant import ( + AZURE_OPENAI_API_VERSION, +) class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +31,17 @@ class _CommonAzureOpenAI: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], InvokeAuthorizationError: [ 
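# --- Illustrative sketch (not part of the diff above): how a mapping like the
# _invoke_error_mapping property above can be used to translate SDK-specific exceptions
# into the runtime's unified invoke errors. The Invoke* classes and the two fake SDK
# errors below are minimal stand-ins, not the project's actual definitions.
class InvokeError(Exception): ...
class InvokeConnectionError(InvokeError): ...
class InvokeRateLimitError(InvokeError): ...

class SDKTimeoutError(Exception): ...    # stands in for e.g. anthropic.APITimeoutError
class SDKRateLimitError(Exception): ...  # stands in for e.g. anthropic.RateLimitError

ERROR_MAPPING = {
    InvokeConnectionError: [SDKTimeoutError],
    InvokeRateLimitError: [SDKRateLimitError],
}


def transform_invoke_error(error: Exception) -> InvokeError:
    # Return the first unified error whose list of SDK exception types matches.
    for unified, sdk_errors in ERROR_MAPPING.items():
        if isinstance(error, tuple(sdk_errors)):
            return unified(str(error))
    return InvokeError(str(error))


print(type(transform_invoke_error(SDKRateLimitError("429"))).__name__)  # InvokeRateLimitError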
openai.AuthenticationError, - openai.PermissionDeniedError + openai.PermissionDeniedError, ], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py index c73e8fd8..b4a4cbba 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,12 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,37 +54,37 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=4096) + _get_max_tokens(default=512, min_val=1, max_val=4096), ], pricing=PriceConfig( input=0.001, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -98,37 +99,37 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + _get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + 
model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -143,32 +144,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -176,34 +174,31 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", + en_US="specifying the format that the model must output", ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -218,32 +213,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + 
name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -251,34 +243,31 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", + en_US="specifying the format that the model must output", ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -293,32 +282,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -326,39 +312,34 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", + en_US="specifying the format that the model must output", ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, - features=[ - ModelFeature.VISION - ], + features=[ModelFeature.VISION], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, @@ -366,32 +347,29 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' + zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. 
Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.", ), required=False, precision=2, @@ -399,34 +377,31 @@ LLM_BASE_MODELS = [ max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", + en_US="specifying the format that the model must output", ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-35-turbo-instruct", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -436,19 +411,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -457,16 +432,16 @@ LLM_BASE_MODELS = [ input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -476,19 +451,19 @@ LLM_BASE_MODELS = [ }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -497,20 +472,18 @@ LLM_BASE_MODELS = [ input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] EMBEDDING_BASE_MODELS = [ AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="text-embedding-ada-002", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -520,17 +493,15 @@ EMBEDDING_BASE_MODELS = [ 
pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -540,17 +511,15 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -560,135 +529,237 @@ EMBEDDING_BASE_MODELS = [ pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } - ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 
'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": [ + "zh-Hans", + "en-US", + "de-DE", + "fr-FR", + "es-ES", + "it-IT", + "th-TH", + "id-ID", + ], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/azure_openai.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/azure_openai.py index 7e3e3fb0..2c14a6bd 100644 --- 
a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -1,11 +1,12 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py index 56fbb72c..c712fd06 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -6,11 +6,23 @@ from typing import Optional, Union, cast import tiktoken from openai import AzureOpenAI, Stream from openai.types import Completion -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, +) from openai.types.chat.chat_completion_message import FunctionCall -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -22,26 +34,47 @@ from model_providers.core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from model_providers.core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelPropertyKey, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.azure_openai._common import ( + _CommonAzureOpenAI, +) +from model_providers.core.model_runtime.model_providers.azure_openai._constant import ( + LLM_BASE_MODELS, + AzureBaseModel, +) logger = logging.getLogger(__name__) class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: 
Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + ai_model_entity = self._get_ai_model_entity( + credentials.get("base_model_name"), model + ) - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) - - if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + if ( + ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) + == LLMMode.CHAT.value + ): # chat model return self._chat_generate( model=model, @@ -51,7 +84,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -62,14 +95,19 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: - - model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get( - ModelPropertyKey.MODE) + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + model_mode = self._get_ai_model_entity( + credentials.get("base_model_name"), model + ).entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: # chat model @@ -79,27 +117,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return self._num_tokens_from_string(credentials, prompt_messages[0].content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError( + "Azure OpenAI API Base Endpoint is required" + ) - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) + ai_model_entity = self._get_ai_model_entity( + credentials.get("base_model_name"), model + ) if not ai_model_entity: - raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') + raise CredentialsValidateFailedError( + f'Base Model Name {credentials["base_model_name"]} is invalid' + ) try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + if ( + ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) 
+ == LLMMode.CHAT.value + ): # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -108,7 +155,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -117,23 +164,33 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity( + credentials.get("base_model_name"), model + ) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( @@ -141,22 +198,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -165,11 +229,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): completion_tokens = response.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(credentials, assistant_text) + prompt_tokens = self._num_tokens_from_string( + credentials, prompt_messages[0].content + ) + completion_tokens = self._num_tokens_from_string( + credentials, 
assistant_text + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -182,23 +252,26 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: - full_text = '' + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[Completion], + prompt_messages: list[PromptMessage], + ) -> Generator: + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -210,11 +283,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): completion_tokens = chunk.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(credentials, full_text) + prompt_tokens = self._num_tokens_from_string( + credentials, prompt_messages[0].content + ) + completion_tokens = self._num_tokens_from_string( + credentials, full_text + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=chunk.model, @@ -224,8 +303,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -235,14 +314,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") @@ -258,17 +343,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - 
} for tool in tools] + extra_model_kwargs["functions"] = [ + { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model response = client.chat.completions.create( @@ -280,27 +368,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ) if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) + return self._handle_chat_generate_stream_response( + model, credentials, response, prompt_messages, tools + ) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + return self._handle_chat_generate_response( + model, credentials, response, prompt_messages, tools + ) + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: assistant_message = response.choices[0].message # assistant_message_tool_calls = assistant_message.tool_calls assistant_message_function_call = assistant_message.function_call # extract tool calls from response # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - function_call = self._extract_response_function_call(assistant_message_function_call) + function_call = self._extract_response_function_call( + assistant_message_function_call + ) tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls + content=assistant_message.content, tool_calls=tool_calls ) # calculate num tokens @@ -310,11 +407,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): completion_tokens = response.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message]) + prompt_tokens = self._num_tokens_from_messages( + credentials, prompt_messages, tools + ) + completion_tokens = self._num_tokens_from_messages( + credentials, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response response = LLMResult( @@ -327,24 +430,31 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: index = 0 - full_assistant_content = '' + 
full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None real_model = model system_fingerprint = None - completion = '' + completion = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + delta.finish_reason is None + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -355,36 +465,44 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # handle process of stream function call if assistant_message_function_call: # message has not ended ever - delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + delta_assistant_message_function_call_storage.arguments += ( + assistant_message_function_call.arguments + ) continue else: # message has ended - assistant_message_function_call = delta_assistant_message_function_call_storage + assistant_message_function_call = ( + delta_assistant_message_function_call_storage + ) delta_assistant_message_function_call_storage = None else: if assistant_message_function_call: # start of stream function call - delta_assistant_message_function_call_storage = assistant_message_function_call + delta_assistant_message_function_call_storage = ( + assistant_message_function_call + ) if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" continue # extract tool calls from response # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - function_call = self._extract_response_function_call(assistant_message_function_call) + function_call = self._extract_response_function_call( + assistant_message_function_call + ) tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=tool_calls, ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content if delta.delta.content else "" yield LLMResultChunk( model=real_model, @@ -393,21 +511,25 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 0 # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - - full_assistant_prompt_message = AssistantPromptMessage( - content=completion + prompt_tokens = self._num_tokens_from_messages( + credentials, prompt_messages, tools + ) + + full_assistant_prompt_message = AssistantPromptMessage(content=completion) + completion_tokens = self._num_tokens_from_messages( + credentials, [full_assistant_prompt_message] ) - completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage 
- usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=real_model, @@ -415,55 +537,52 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + message=AssistantPromptMessage(content=""), + finish_reason="stop", + usage=usage, + ), ) @staticmethod - def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: - + def _extract_response_tool_calls( + response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall], + ) -> list[AssistantPromptMessage.ToolCall]: tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + arguments=response_tool_call.function.arguments, ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call.id, type=response_tool_call.type, - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls @staticmethod - def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: - + def _extract_response_function_call( + response_function_call: FunctionCall | ChoiceDeltaFunctionCall, + ) -> AssistantPromptMessage.ToolCall: tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_function_call.name, - arguments=response_function_call.arguments + arguments=response_function_call.arguments, ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @staticmethod def _convert_prompt_message_to_dict(message: PromptMessage) -> dict: - if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): @@ -472,20 +591,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) sub_message_dict = { "type": "text", - "text": message_content.data + "text": message_content.data, } sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + message_content = cast( + ImagePromptMessageContent, message_content + ) sub_message_dict = { "type": "image_url", "image_url": { "url": message_content.data, - "detail": message_content.detail.value - } + "detail": message_content.detail.value, + }, } sub_messages.append(sub_message_dict) @@ -514,7 +637,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): message_dict = { "role": "function", "content": message.content, - "name": message.tool_call_id + "name": message.tool_call_id, } else: raise ValueError(f"Got unknown type 
{message}") @@ -524,10 +647,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, + credentials: dict, + text: str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -538,13 +665,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, + credentials: dict, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -578,10 +709,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -611,41 +742,42 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens @staticmethod - def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - + def _num_tokens_for_tools( + encoding: tiktoken.Encoding, tools: list[PromptMessageTool] + ) -> int: num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for 
field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 227edbec..2e4c8323 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -4,10 +4,19 @@ from typing import IO, Optional from openai import AzureOpenAI from model_providers.core.model_runtime.entities.model_entities import AIModelEntity -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from model_providers.core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.speech2text_model import ( + Speech2TextModel, +) +from model_providers.core.model_runtime.model_providers.azure_openai._common import ( + _CommonAzureOpenAI, +) +from model_providers.core.model_runtime.model_providers.azure_openai._constant import ( + SPEECH2TEXT_BASE_MODELS, + AzureBaseModel, +) class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): @@ -15,9 +24,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: """ Invoke speech2text model @@ -40,12 +49,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: + def _speech2text_invoke( + self, model: str, credentials: dict, file: IO[bytes] + ) -> str: """ Invoke speech2text model @@ -64,11 +75,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): return response.text - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity( + credentials["base_model_name"], model + ) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 8f0420b1..17e442fe 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,28 +7,46 @@ import numpy as np import tiktoken from openai import AzureOpenAI -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from model_providers.core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + PriceType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.azure_openai._common import ( + _CommonAzureOpenAI, +) +from model_providers.core.model_runtime.model_providers.azure_openai._constant import ( + EMBEDDING_BASE_MODELS, + AzureBaseModel, +) class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, 
model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +62,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -58,8 +74,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + texts=tokens[i : i + max_chunks], + extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens @@ -78,7 +94,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): model=model, client=client, texts="", - extra_model_kwargs=extra_model_kwargs + extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens @@ -89,15 +105,11 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): # calc usage usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens + model=model, credentials=credentials, tokens=used_tokens ) return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name + embeddings=embeddings, usage=usage, model=base_model_name ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -105,7 +117,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +130,78 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError( + "Azure OpenAI API Base Endpoint is required" + ) - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not 
self._get_ai_model_entity(credentials['base_model_name'], model): - raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') + if not self._get_ai_model_entity(credentials["base_model_name"], model): + raise CredentialsValidateFailedError( + f'Base Model Name {credentials["base_model_name"]} is invalid' + ) try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} + model=model, client=client, texts=["ping"], extra_model_kwargs={} ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity( + credentials["base_model_name"], model + ) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, + client: AzureOpenAI, + texts: Union[list[str], str], + extra_model_kwargs: dict, + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if ( + "encoding_format" in extra_model_kwargs + and extra_model_kwargs["encoding_format"] == "base64" + ): # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [ + list( + np.frombuffer(base64.b64decode(data.embedding), dtype="float32") + ) + for data in response.data + ], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: input_price_info = self.get_price( model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -179,7 +212,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py index 4475b16e..9474e9b1 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -3,16 +3,24 @@ import copy from functools import reduce from io import BytesIO from typing import Optional + from fastapi.responses import StreamingResponse from openai import AzureOpenAI from pydub import AudioSegment from model_providers.core.model_runtime.entities.model_entities import AIModelEntity from 
model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel -from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from model_providers.core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel +from model_providers.core.model_runtime.model_providers.azure_openai._common import ( + _CommonAzureOpenAI, +) +from model_providers.core.model_runtime.model_providers.azure_openai._constant import ( + TTS_BASE_MODELS, + AzureBaseModel, +) from model_providers.extensions.ext_storage import storage @@ -21,8 +29,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ) -> any: """ _invoke text2speech model @@ -36,20 +52,34 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] + for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) if streaming: - return StreamingResponse(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice), media_type='text/event-stream') + return StreamingResponse( + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice, + ), + media_type="text/event-stream", + ) else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + return self._tts_invoke( + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) - def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + def validate_credentials( + self, model: str, credentials: dict, user: Optional[str] = None + ) -> None: """ validate credentials text2speech model @@ -62,13 +92,15 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): self._tts_invoke( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse: + def _tts_invoke( + self, model: str, credentials: dict, content_text: str, voice: str + ) -> StreamingResponse: """ _tts_invoke text2speech model @@ -82,13 +114,25 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, 
credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) audio_bytes_list = list() # Create a thread pool and map the function to the list of sentences - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice, - credentials=credentials) for sentence in sentences] + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + futures = [ + executor.submit( + self._process_sentence, + sentence=sentence, + model=model, + voice=voice, + credentials=credentials, + ) + for sentence in sentences + ] for future in futures: try: if future.result(): @@ -97,8 +141,11 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): raise InvokeBadRequestError(str(ex)) if len(audio_bytes_list) > 0: - audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] + audio_segments = [ + AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) + for audio_bytes in audio_bytes_list + if audio_bytes + ] combined_segment = reduce(lambda x, y: x + y, audio_segments) buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) @@ -108,8 +155,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): raise InvokeBadRequestError(str(ex)) # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + ) -> any: """ _tts_invoke_streaming text2speech model @@ -122,24 +175,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): """ # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + if not voice or voice not in self.get_tts_model_voices( + model=model, credentials=credentials + ): voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' + file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}" try: client = AzureOpenAI(**credentials_kwargs) - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) for sentence in sentences: - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + response = client.audio.speech.create( + model=model, voice=voice, input=sentence.strip() + ) # response.stream_to_file(file_path) storage.save(file_path, response.read()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -152,12 +210,18 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, 
TTSModel): # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + response = client.audio.speech.create( + model=model, voice=voice, input=sentence.strip() + ) if isinstance(response.read(), bytes): return response.read() - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity( + credentials["base_model_name"], model + ) return ai_model_entity.entity @staticmethod diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/baichuan.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/baichuan.py index 731b6efc..17ef6694 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/baichuan.py @@ -1,11 +1,16 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -20,11 +25,12 @@ class BaichuanProvider(ModelProvider): # Use `baichuan2-turbo` model for validate, model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials + model="baichuan2-turbo", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb..371f74de 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,20 @@ import re class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc. 
- text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int( + cls.count_chinese_characters(text) + + cls.count_english_vocabularies(text) * 1.3 + ) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index a57fdd2b..9fcffccc 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -18,153 +18,188 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu class BaichuanMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" # Baichuan does not have system message - _SYSTEM = 'system' + _SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role + class BaichuanModel: api_key: str secret_key: str - def __init__(self, api_key: str, secret_key: str = '') -> None: + def __init__(self, api_key: str, secret_key: str = "") -> None: self.api_key = api_key self.secret_key = secret_key def _model_mapping(self, model: str) -> str: return { - 'baichuan2-turbo': 'Baichuan2-Turbo', - 'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k', - 'baichuan2-53b': 'Baichuan2-53B', + "baichuan2-turbo": "Baichuan2-Turbo", + "baichuan2-turbo-192k": "Baichuan2-Turbo-192k", + "baichuan2-53b": "Baichuan2-53B", }[model] def _handle_chat_generate_response(self, response) -> BaichuanMessage: - resp = response.json() - choices = resp.get('choices', []) - message = BaichuanMessage(content='', role='assistant') - for choice in choices: - message.content += choice['message']['content'] - message.role = choice['message']['role'] - if choice['finish_reason']: - message.stop_reason = choice['finish_reason'] + resp = response.json() + choices = resp.get("choices", []) + message = BaichuanMessage(content="", role="assistant") + for choice in choices: + message.content += choice["message"]["content"] + message.role = choice["message"]["role"] + if choice["finish_reason"]: + message.stop_reason = choice["finish_reason"] + + if "usage" in resp: + message.usage = { + "prompt_tokens": resp["usage"]["prompt_tokens"], + "completion_tokens": resp["usage"]["completion_tokens"], + "total_tokens": resp["usage"]["total_tokens"], + } + + return message - if 'usage' in resp: - message.usage = { - 'prompt_tokens': resp['usage']['prompt_tokens'], - 'completion_tokens': resp['usage']['completion_tokens'], - 'total_tokens': resp['usage']['total_tokens'], - } - - return message - def 
_handle_chat_stream_generate_response(self, response) -> Generator: for line in response.iter_lines(): if not line: continue - line = line.decode('utf-8') + line = line.decode("utf-8") # remove the first `data: ` prefix - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() try: data = loads(line) except Exception as e: - if line.strip() == '[DONE]': + if line.strip() == "[DONE]": return - choices = data.get('choices', []) + choices = data.get("choices", []) # save stop reason temporarily - stop_reason = '' + stop_reason = "" for choice in choices: - if 'finish_reason' in choice and choice['finish_reason']: - stop_reason = choice['finish_reason'] + if "finish_reason" in choice and choice["finish_reason"]: + stop_reason = choice["finish_reason"] - if len(choice['delta']['content']) == 0: + if len(choice["delta"]["content"]) == 0: continue - yield BaichuanMessage(**choice['delta']) + yield BaichuanMessage(**choice["delta"]) # if there is usage, the response is the last one, yield it and return - if 'usage' in data: - message = BaichuanMessage(content='', role='assistant') + if "usage" in data: + message = BaichuanMessage(content="", role="assistant") message.usage = { - 'prompt_tokens': data['usage']['prompt_tokens'], - 'completion_tokens': data['usage']['completion_tokens'], - 'total_tokens': data['usage']['total_tokens'], + "prompt_tokens": data["usage"]["prompt_tokens"], + "completion_tokens": data["usage"]["completion_tokens"], + "total_tokens": data["usage"]["total_tokens"], } message.stop_reason = stop_reason yield message - def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any]) \ - -> dict[str, Any]: - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': + def _build_parameters( + self, + model: str, + stream: bool, + messages: list[BaichuanMessage], + parameters: dict[str, Any], + ) -> dict[str, Any]: + if ( + model == "baichuan2-turbo" + or model == "baichuan2-turbo-192k" + or model == "baichuan2-53b" + ): prompt_messages = [] for message in messages: - if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value: + if ( + message.role == BaichuanMessage.Role.USER.value + or message.role == BaichuanMessage.Role._SYSTEM.value + ): # check if the latest message is a user message - if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value: - prompt_messages[-1]['content'] += message.content + if ( + len(prompt_messages) > 0 + and prompt_messages[-1]["role"] + == BaichuanMessage.Role.USER.value + ): + prompt_messages[-1]["content"] += message.content else: - prompt_messages.append({ - 'content': message.content, - 'role': BaichuanMessage.Role.USER.value, - }) + prompt_messages.append( + { + "content": message.content, + "role": BaichuanMessage.Role.USER.value, + } + ) elif message.role == BaichuanMessage.Role.ASSISTANT.value: - prompt_messages.append({ - 'content': message.content, - 'role': message.role, - }) + prompt_messages.append( + { + "content": message.content, + "role": message.role, + } + ) # [baichuan] frequency_penalty must be between 1 and 2 - if 'frequency_penalty' in parameters: - if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: - parameters['frequency_penalty'] = 1 + if "frequency_penalty" in parameters: + if ( + parameters["frequency_penalty"] < 1 + or parameters["frequency_penalty"] > 2 + ): + parameters["frequency_penalty"] = 1 
# turbo api accepts flat parameters return { - 'model': self._model_mapping(model), - 'stream': stream, - 'messages': prompt_messages, + "model": self._model_mapping(model), + "stream": stream, + "messages": prompt_messages, **parameters, } else: raise BadRequestError(f"Unknown model: {model}") - + def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': + if ( + model == "baichuan2-turbo" + or model == "baichuan2-turbo-192k" + or model == "baichuan2-53b" + ): # there is no secret key for turbo api return { - 'Content-Type': 'application/json', - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ', - 'Authorization': 'Bearer ' + self.api_key, + "Content-Type": "application/json", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ", + "Authorization": "Bearer " + self.api_key, } else: raise BadRequestError(f"Unknown model: {model}") - - def _calculate_md5(self, input_string): - return md5(input_string.encode('utf-8')).hexdigest() - def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any], timeout: int) \ - -> Union[Generator, BaichuanMessage]: - - if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': - api_base = 'https://api.baichuan-ai.com/v1/chat/completions' + def _calculate_md5(self, input_string): + return md5(input_string.encode("utf-8")).hexdigest() + + def generate( + self, + model: str, + stream: bool, + messages: list[BaichuanMessage], + parameters: dict[str, Any], + timeout: int, + ) -> Union[Generator, BaichuanMessage]: + if ( + model == "baichuan2-turbo" + or model == "baichuan2-turbo-192k" + or model == "baichuan2-53b" + ): + api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") - + try: data = self._build_parameters(model, stream, messages, parameters) headers = self._build_headers(model, data) @@ -177,35 +212,37 @@ class BaichuanModel: headers=headers, data=dumps(data), timeout=timeout, - stream=stream + stream=stream, ) except Exception as e: raise InternalServerError(f"Failed to invoke model: {e}") - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': + elif err == "insufficient_quota": raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': + elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) - elif 'rate' in err: + elif "rate" in err: raise RateLimitReachedError(msg) - elif 'internal' in err: + elif "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + if stream: return self._handle_chat_stream_generate_response(response) else: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py 
b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a..4e56e58d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py index 80399a30..b1691b5f 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,7 +1,11 @@ from collections.abc import Generator from typing import cast -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -17,10 +21,19 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import ( + BaichuanTokenizer, +) +from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import ( + BaichuanMessage, + BaichuanModel, +) from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, InsufficientAccountBalance, @@ -32,20 +45,43 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu class BaichuanLarguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + 
stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return BaichuanTokenizer._get_num_tokens(text) @@ -57,10 +93,10 @@ class BaichuanLarguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -87,89 +123,123 @@ class BaichuanLarguageModel(LargeLanguageModel): message_dict = {"role": "user", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict def validate_credentials(self, model: str, credentials: dict) -> None: # ping instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') + api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "") ) try: - instance.generate(model=model, stream=False, messages=[ - BaichuanMessage(content='ping', role='user') - ], parameters={ - 'max_tokens': 1, - }, timeout=60) + instance.generate( + model=model, + stream=False, + messages=[BaichuanMessage(content="ping", role="user")], + parameters={ + "max_tokens": 1, + }, + timeout=60, + ) except Exception as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: if tools is not None and len(tools) > 0: raise InvokeBadRequestError("Baichuan model doesn't support tools") - + instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') + api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "") ) # convert prompt messages to baichuan messages messages = [ BaichuanMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - 
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + ) if stream: - return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) - - return self._handle_chat_generate_response(model, prompt_messages, credentials, response) + return self._handle_chat_generate_stream_response( + model, prompt_messages, credentials, response + ) - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: BaichuanMessage) -> LLMResult: + return self._handle_chat_generate_response( + model, prompt_messages, credentials, response + ) + + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: BaichuanMessage, + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[BaichuanMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[BaichuanMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) else: @@ -179,10 +249,11 @@ class BaichuanLarguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) @@ -197,21 +268,13 @@ class BaichuanLarguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + 
InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index ae706de3..0f008858 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -5,7 +5,10 @@ from typing import Optional from requests import post from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -14,9 +17,15 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import ( + BaichuanTokenizer, +) from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, InsufficientAccountBalance, @@ -31,11 +40,16 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model. 
""" - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -45,27 +59,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: - # embeding chunk + # embedding chunk chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user + model=model, api_key=api_key, texts=chunk, user=user ) embeddings.extend(chunk_embeddings) @@ -75,16 +86,15 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): model=model, embeddings=embeddings, usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + model=model, credentials=credentials, tokens=token_usage + ), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -96,55 +106,53 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ url = self.api_base headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' + "Authorization": "Bearer " + api_key, + "Content-Type": "application/json", } - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': + elif err == "insufficient_quota": raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': + elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise 
InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,33 +178,27 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -210,7 +212,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -221,7 +223,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/bedrock.py b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/bedrock.py index 82e56ab6..27f77499 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/bedrock.py @@ -1,11 +1,16 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -20,11 +25,12 @@ class BedrockProvider(ModelProvider): # Use `gemini-pro` model for validate, 
model_instance.validate_credentials( - model='amazon.titan-text-lite-v1', - credentials=credentials + model="amazon.titan-text-lite-v1", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py index c99d8d2a..48a9e990 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -13,7 +13,11 @@ from botocore.exceptions import ( UnknownServiceError, ) -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -29,18 +33,28 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) logger = logging.getLogger(__name__) -class BedrockLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: +class BedrockLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -55,10 +69,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + return self._generate( + model, credentials, prompt_messages, model_parameters, stop, stream, user + ) - def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -68,7 +89,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return:md = genai.GenerativeModel(model) """ - prefix = model.split('.')[0] + prefix = model.split(".")[0] if 
isinstance(messages, str): prompt = messages @@ -76,8 +97,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): prompt = self._convert_messages_to_prompt(messages, prefix) return self._get_num_tokens_by_gpt2(prompt) - - def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str: + + def _convert_messages_to_prompt( + self, model_prefix: str, messages: list[PromptMessage] + ) -> str: """ Format a list of messages into a full prompt for the Google model @@ -85,7 +108,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. """ messages = messages.copy() # don't mutate the original list - + text = "".join( self._convert_one_message_to_text(message, model_prefix) for message in messages @@ -101,32 +124,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param credentials: model credentials :return: """ - + try: ping_message = UserPromptMessage(content="ping") - self._generate(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters={}, - stream=False) - + self._generate( + model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters={}, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" - raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) + raise CredentialsValidateFailedError( + str(self._map_client_to_invoke_error(error_code, full_error_msg)) + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str + ) -> str: """ Convert a single message to a string. :param message: PromptMessage to convert. :return: String representation of the message. """ - + if model_prefix == "anthropic": human_prompt_prefix = "\n\nHuman:" human_prompt_postfix = "" @@ -141,7 +170,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): human_prompt_prefix = "\n\nUser:" human_prompt_postfix = "" ai_prompt = "\n\nBot:" - + else: human_prompt_prefix = "" human_prompt_postfix = "" @@ -160,7 +189,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -168,7 +199,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :return: Combined string with necessary human_prompt and ai_prompt tags. 
""" if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): @@ -182,23 +213,36 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model_prefix: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = dict() if model_prefix == "amazon": - payload["textGenerationConfig"] = { **model_parameters } - payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else []) - - payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) - + payload["textGenerationConfig"] = {**model_parameters} + payload["textGenerationConfig"]["stopSequences"] = ["User:"] + ( + stop if stop else [] + ) + + payload["inputText"] = self._convert_messages_to_prompt( + prompt_messages, model_prefix + ) + elif model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") payload["topP"] = model_parameters.get("topP") payload["maxTokens"] = model_parameters.get("maxTokens") - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + payload["prompt"] = self._convert_messages_to_prompt( + prompt_messages, model_prefix + ) # jurassic models only support a single stop sequence if stop: @@ -212,28 +256,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel): payload["countPenalty"] = {model_parameters.get("countPenalty")} elif model_prefix == "anthropic": - payload = { **model_parameters } - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + payload = {**model_parameters} + payload["prompt"] = self._convert_messages_to_prompt( + prompt_messages, model_prefix + ) payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else []) - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + elif model_prefix == "meta": - payload = { **model_parameters } - payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) + payload = {**model_parameters} + payload["prompt"] = self._convert_messages_to_prompt( + prompt_messages, model_prefix + ) else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -246,19 +300,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = 
Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials["aws_access_key_id"], - aws_secret_access_key=credentials["aws_secret_access_key"] + aws_secret_access_key=credentials["aws_secret_access_key"], ) - model_prefix = model.split('.')[0] - payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream) + model_prefix = model.split(".")[0] + payload = self._create_payload( + model_prefix, prompt_messages, model_parameters, stop, stream + ) # need workaround for ai21 models which doesn't support streaming if stream and model_prefix != "ai21": @@ -267,18 +321,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel): invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) + body_jsonstr = json.dumps(payload) response = invoke( modelId=model, contentType="application/json", - accept= "*/*", - body=body_jsonstr + accept="*/*", + body=body_jsonstr, ) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -287,15 +341,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except Exception as ex: raise InvokeError(str(ex)) - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -305,7 +367,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -313,43 +375,51 @@ class BedrockLargeLanguageModel(LargeLanguageModel): raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "amazon": - output = response_body.get("results")[0].get("outputText").strip('\n') + output = response_body.get("results")[0].get("outputText").strip("\n") prompt_tokens = response_body.get("inputTextTokenCount") completion_tokens = response_body.get("results")[0].get("tokenCount") elif model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + completion_tokens = len( + response_body.get("completions")[0].get("data").get("tokens") + ) elif 
model_prefix == "anthropic": output = response_body.get("completion") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens( + model, credentials, output if output else "" + ) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens( + model, credentials, output if output else "" + ) + elif model_prefix == "meta": - output = response_body.get("generation").strip('\n') + output = response_body.get("generation").strip("\n") prompt_tokens = response_body.get("prompt_token_count") completion_tokens = response_body.get("generation_token_count") else: - raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") + raise ValueError( + f"Got unknown model prefix {model_prefix} when handling block response" + ) # construct assistant message from output - assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # construct response result = LLMResult( @@ -361,8 +431,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -372,48 +447,52 @@ class BedrockLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = response_body.get('completions')[0].get('finish_reason') + content = response_body.get("completions")[0].get("data").get("text") + finish_reason = response_body.get("completions")[0].get("finish_reason") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + completion_tokens = len( + response_body.get("completions")[0].get("data").get("tokens") + ) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - finish_reason=finish_reason, - usage=usage - ) - ) + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content), + finish_reason=finish_reason, + usage=usage, + ), + ) return - - stream 
= response.get('body') + + stream = response.get("body") if not stream: - raise InvokeError('No response body') - + raise InvokeError("No response body") + index = -1 for event in stream: - chunk = event.get('chunk') - + chunk = event.get("chunk") + if not chunk: exception_name = next(iter(event)) full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" raise self._map_client_to_invoke_error(exception_name, full_ex_msg) - payload = json.loads(chunk.get('bytes').decode()) + payload = json.loads(chunk.get("bytes").decode()) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "amazon": - content_delta = payload.get("outputText").strip('\n') + content_delta = payload.get("outputText").strip("\n") finish_reason = payload.get("completion_reason") - + elif model_prefix == "anthropic": content_delta = payload.get("completion") finish_reason = payload.get("stop_reason") @@ -421,38 +500,45 @@ class BedrockLargeLanguageModel(LargeLanguageModel): elif model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - + elif model_prefix == "meta": - content_delta = payload.get("generation").strip('\n') + content_delta = payload.get("generation").strip("\n") finish_reason = payload.get("stop_reason") - + else: - raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") + raise ValueError( + f"Got unknown model prefix {model_prefix} when handling stream response" + ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content = content_delta if content_delta else '', + content=content_delta if content_delta else "", ) index += 1 - + if not finish_reason: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + index=index, message=assistant_prompt_message + ), ) else: # get num tokens from metrics in last chunk - prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"] - completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"] + prompt_tokens = payload["amazon-bedrock-invocationMetrics"][ + "inputTokenCount" + ] + completion_tokens = payload["amazon-bedrock-invocationMetrics"][ + "outputTokenCount" + ] # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -460,10 +546,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): index=index, message=assistant_prompt_message, finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -479,10 +565,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - - def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: + + def _map_client_to_invoke_error( + self, error_code: str, error_msg: str + ) -> type[InvokeError]: """ Map client error to invoke error @@ -497,7 +585,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return InvokeBadRequestError(error_msg) elif error_code in ["ThrottlingException", 
"ServiceQuotaExceededException"]: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in [ + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + ]: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py index 1e48e52c..f0bd8825 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -21,11 +25,12 @@ class ChatGLMProvider(ModelProvider): # Use `chatglm3-6b` model for validate, model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials + model="chatglm3-6b", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py index 6c58362b..1f798a26 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -20,7 +20,11 @@ from openai import ( from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -37,18 +41,29 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) from model_providers.core.model_runtime.utils import helper logger = logging.getLogger(__name__) + class 
ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +86,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +116,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +149,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -155,7 +180,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ - self._check_chatglm_parameters(model=model, model_parameters=model_parameters, tools=tools) + self._check_chatglm_parameters( + model=model, model_parameters=model_parameters, tools=tools + ) kwargs = self._to_client_kwargs(credentials) # init model client @@ -163,13 +190,13 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ + extra_model_kwargs["functions"] = [ helper.dump_model(tool) for tool in tools ] @@ -178,21 +205,29 @@ class 
ChatGLMLargeLanguageModel(LargeLanguageModel): model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + tools=tools, + prompt_messages=prompt_messages, ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + tools=tools, + prompt_messages=prompt_messages, ) - - def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: + + def _check_chatglm_parameters( + self, model: str, model_parameters: dict, tools: list[PromptMessageTool] + ) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +247,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +258,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +274,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function + id=0, type="function", function=function ) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +297,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": join(credentials['api_base'], 'v1') + "base_url": join(credentials["api_base"], "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,35 
+318,46 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and ( + delta.delta.content is None or delta.delta.content == "" + ): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls( + function_calls if function_calls else [] + ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=assistant_message_tool_calls, ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +366,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +381,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -356,18 +406,28 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): # convert function call to tool call function_calls = assistant_message.function_call - tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) + tool_calls = self._extract_response_tool_calls( + [function_calls] if function_calls else [] + ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls + content=assistant_message.content, tool_calls=tool_calls ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = 
self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=tools + ) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) response = LLMResult( model=model, @@ -378,8 +438,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): ) return response - - def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: + + def _num_tokens_from_string( + self, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +457,21 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +480,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +518,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += 
tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/cohere.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/cohere.py index 6a2b91cd..f81a188c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/cohere.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/cohere.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -21,11 +25,12 @@ class CohereProvider(ModelProvider): # Use `rerank-english-v2.0` model for validate, model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials + model="rerank-english-v2.0", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py index ad611f11..620a5b91 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py @@ -7,7 +7,12 @@ from cohere.responses import Chat, Generations from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration from cohere.responses.generation import StreamingGenerations, StreamingText -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -17,7 +22,12 @@ from model_providers.core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelType, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -26,8 +36,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) 
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) logger = logging.getLogger(__name__) @@ -37,11 +51,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -66,7 +86,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -76,11 +96,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -95,9 +120,13 @@ class CohereLargeLanguageModel(LargeLanguageModel): try: if model_mode == LLMMode.CHAT: - return self._num_tokens_from_messages(model, credentials, prompt_messages) + return self._num_tokens_from_messages( + model, credentials, prompt_messages + ) else: - return self._num_tokens_from_string(model, credentials, prompt_messages[0].content) + return self._num_tokens_from_string( + model, credentials, prompt_messages[0].content + ) except Exception as e: raise self._transform_invoke_error(e) @@ -117,30 +146,37 @@ class CohereLargeLanguageModel(LargeLanguageModel): self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -154,10 +190,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - 
client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get("api_key")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop response = client.generate( prompt=prompt_messages[0].content, @@ -167,13 +203,21 @@ class CohereLargeLanguageModel(LargeLanguageModel): ) if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, response: Generations, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: Generations, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -186,29 +230,34 @@ class CohereLargeLanguageModel(LargeLanguageModel): assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens - prompt_tokens = response.meta['billed_units']['input_tokens'] - completion_tokens = response.meta['billed_units']['output_tokens'] + prompt_tokens = response.meta["billed_units"]["input_tokens"] + completion_tokens = response.meta["billed_units"]["output_tokens"] # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response response = LLMResult( model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, - usage=usage + usage=usage, ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: StreamingGenerations, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -218,7 +267,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, StreamingText): chunk = cast(StreamingText, chunk) @@ -228,9 +277,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -240,33 +287,42 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 elif chunk is None: # calculate num tokens - prompt_tokens = response.meta['billed_units']['input_tokens'] - completion_tokens = response.meta['billed_units']['output_tokens'] + prompt_tokens = response.meta["billed_units"]["input_tokens"] + completion_tokens = response.meta["billed_units"]["output_tokens"] # transform usage - usage = self._calc_response_usage(model, credentials, 
prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=response.finish_reason, - usage=usage - ) + usage=usage, + ), ) break - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -280,17 +336,23 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get("api_key")) if user: - model_parameters['user_name'] = user + model_parameters["user_name"] = user - message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + ( + message, + chat_histories, + ) = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) # chat model real_model = model - if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + if ( + self.get_model_schema(model, credentials).fetch_from + == FetchFrom.PREDEFINED_MODEL + ): + real_model = model.removesuffix("-chat") response = client.chat( message=message, @@ -302,13 +364,22 @@ class CohereLargeLanguageModel(LargeLanguageModel): ) if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop) + return self._handle_chat_generate_stream_response( + model, credentials, response, prompt_messages, stop + ) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop) + return self._handle_chat_generate_response( + model, credentials, response, prompt_messages, stop + ) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat, - prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ - -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Chat, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -322,23 +393,25 @@ class CohereLargeLanguageModel(LargeLanguageModel): assistant_text = response.text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message]) + prompt_tokens = self._num_tokens_from_messages( + model, credentials, prompt_messages + ) + completion_tokens = self._num_tokens_from_messages( + model, credentials, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, 
credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) if stop: # enforce stop tokens assistant_text = self.enforce_stop_tokens(assistant_text, stop) - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # transform response response = LLMResult( @@ -346,14 +419,19 @@ class CohereLargeLanguageModel(LargeLanguageModel): prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage, - system_fingerprint=response.preamble + system_fingerprint=response.preamble, ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: StreamingChat, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -364,18 +442,26 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: llm response chunk generator """ - def final_response(full_text: str, index: int, finish_reason: Optional[str] = None, - preamble: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + index: int, + finish_reason: Optional[str] = None, + preamble: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text + prompt_tokens = self._num_tokens_from_messages( + model, credentials, prompt_messages + ) + + full_assistant_prompt_message = AssistantPromptMessage(content=full_text) + completion_tokens = self._num_tokens_from_messages( + model, credentials, [full_assistant_prompt_message] ) - completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) return LLMResultChunk( model=model, @@ -383,14 +469,14 @@ class CohereLargeLanguageModel(LargeLanguageModel): system_fingerprint=preamble, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, StreamTextGeneration): chunk = cast(StreamTextGeneration, chunk) @@ -400,14 +486,12 @@ class CohereLargeLanguageModel(LargeLanguageModel): continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) # stop # notice: This logic can only cover few stop scenarios if stop and text in stop: - yield final_response(full_assistant_content, index, 'stop') + yield final_response(full_assistant_content, index, "stop") break full_assistant_content += text @@ -418,17 +502,23 @@ class CohereLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 elif isinstance(chunk, 
StreamEnd): chunk = cast(StreamEnd, chunk) - yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble) + yield final_response( + full_assistant_content, + index, + chunk.finish_reason, + response.preamble, + ) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[dict]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[dict]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -441,9 +531,9 @@ class CohereLargeLanguageModel(LargeLanguageModel): # get latest message from chat histories and pop it if len(chat_histories) > 0: latest_message = chat_histories.pop() - message = latest_message['message'] + message = latest_message["message"] else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories @@ -456,10 +546,12 @@ class CohereLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): message_dict = {"role": "USER", "message": message.content} else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) sub_message_text += message_content.data message_dict = {"role": "USER", "message": sub_message_text} @@ -487,47 +579,53 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get("api_key")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return response.length - def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int: + def _num_tokens_from_messages( + self, model: str, credentials: dict, messages: list[PromptMessage] + ) -> int: """Calculate num tokens Cohere model.""" messages = [self._convert_prompt_message_to_dict(m) for m in messages] - message_strs = [f"{message['role']}: {message['message']}" for message in messages] + message_strs = [ + f"{message['role']}: {message['message']}" for message in messages + ] message_str = "\n".join(message_strs) real_model = model - if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + if ( + self.get_model_schema(model, credentials).fetch_from + == FetchFrom.PREDEFINED_MODEL + ): + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. 
- :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -537,18 +635,16 @@ class CohereLargeLanguageModel(LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=[feature for feature in base_model_schema_features], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - key: property for key, property in base_model_schema_model_properties.items() + key: property + for key, property in base_model_schema_model_properties.items() }, parameter_rules=[rule for rule in base_model_schema_parameters_rules], - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -564,14 +660,12 @@ class CohereLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.CohereConnectionError - ], + InvokeConnectionError: [cohere.CohereConnectionError], InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], InvokeBadRequestError: [ cohere.CohereAPIError, cohere.CohereError, - ] + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py index b3691ee2..86d0b22f 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -2,7 +2,10 @@ from typing import Optional import cohere -from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from model_providers.core.model_runtime.entities.rerank_entities import ( + RerankDocument, + RerankResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.rerank_model import ( + RerankModel, +) class CohereRerankModel(RerankModel): @@ -20,10 +27,16 @@ class CohereRerankModel(RerankModel): Model class for Cohere rerank model. 
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,26 +50,18 @@ class CohereRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key')) - results = client.rerank( - query=query, - documents=docs, - model=model, - top_n=top_n - ) + client = cohere.Client(credentials.get("api_key")) + results = client.rerank(query=query, documents=docs, model=model, top_n=top_n) rerank_documents = [] for idx, result in enumerate(results): # format document rerank_document = RerankDocument( index=result.index, - text=result.document['text'], + text=result.document["text"], score=result.relevance_score, ) @@ -67,10 +72,7 @@ class CohereRerankModel(RerankModel): else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -91,7 +93,7 @@ class CohereRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,5 +118,5 @@ class CohereRerankModel(RerankModel): InvokeBadRequestError: [ cohere.CohereAPIError, cohere.CohereError, - ] + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 56a73601..bf4821e9 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -6,7 +6,10 @@ import numpy as np from cohere.responses import Tokens from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -15,8 +18,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) class 
CohereTextEmbeddingModel(TextEmbeddingModel): @@ -24,9 +31,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -47,13 +58,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): for i, text in enumerate(texts): tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text + model=model, credentials=credentials, text=text ) for j in range(0, tokenize_response.length, context_size): - tokens += [tokenize_response.token_strings[j: j + context_size]] + tokens += [tokenize_response.token_strings[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -64,7 +73,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + texts=["".join(token) for token in tokens[i : i + max_chunks]], ) used_tokens += embedding_used_tokens @@ -80,9 +89,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -93,16 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): # calc usage usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens + model=model, credentials=credentials, tokens=used_tokens ) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,13 +117,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: response = self._tokenize( - model=model, - credentials=credentials, - text=full_text + model=model, credentials=credentials, text=full_text ) except Exception as e: raise self._transform_invoke_error(e) @@ -141,12 +140,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): return Tokens([], [], {}) # initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get("api_key")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return response @@ -160,15 +156,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): """ try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, credentials: dict, texts: list[str] + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -178,18 +172,20 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: embeddings and used tokens """ # 
initialize client - client = cohere.Client(credentials.get('api_key')) + client = cohere.Client(credentials.get("api_key")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query' + input_type="search_document" if len(texts) > 1 else "search_query", ) - return response.embeddings, response.meta['billed_units']['input_tokens'] + return response.embeddings, response.meta["billed_units"]["input_tokens"] - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -203,7 +199,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -214,7 +210,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -230,14 +226,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.CohereConnectionError - ], + InvokeConnectionError: [cohere.CohereConnectionError], InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], InvokeBadRequestError: [ cohere.CohereAPIError, cohere.CohereError, - ] + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/google/google.py b/model-providers/model_providers/core/model_runtime/model_providers/google/google.py index fa426593..88c8e62b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/google/google.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/google/google.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -21,11 +25,12 @@ class GoogleProvider(ModelProvider): # Use `gemini-pro` model for validate, model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials + model="gemini-pro", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py index 4fecd526..b3a08d87 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py @@ -5,10 +5,19 @@ from typing import Optional, Union import google.api_core.exceptions as 
exceptions import google.generativeai as genai import google.generativeai.client as client -from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory +from google.generativeai.types import ( + ContentType, + GenerateContentResponse, + HarmBlockThreshold, + HarmCategory, +) from google.generativeai.types.content_types import to_part -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -26,8 +35,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) logger = logging.getLogger(__name__) @@ -42,12 +55,17 @@ if you are not sure about the structure. class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,10 +80,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + return self._generate( + model, credentials, prompt_messages, model_parameters, stop, stream, user + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -89,8 +114,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) return text.rstrip() @@ -106,16 +130,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel): try: ping_message = PromptMessage(content="ping", role="system") - self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) + self._generate( + model, credentials, [ping_message], {"max_tokens_to_sample": 5} + ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: 
dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -129,14 +160,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop( + "max_tokens_to_sample", None + ) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -146,14 +177,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel): content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) else: history.append(content) - # Create a new ClientManager with tenant's API key new_client_manager = client._ClientManager() new_client_manager.configure(api_key=credentials["google_api_key"]) @@ -161,7 +191,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): google_model._client = new_custom_client - safety_settings={ + safety_settings = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, @@ -170,20 +200,27 @@ class GoogleLargeLanguageModel(LargeLanguageModel): response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, - safety_settings=safety_settings + safety_settings=safety_settings, ) if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: GenerateContentResponse, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -194,16 +231,18 @@ class GoogleLargeLanguageModel(LargeLanguageModel): :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - 
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+        completion_tokens = self.get_num_tokens(
+            model, credentials, [assistant_prompt_message]
+        )
         # transform usage
-        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+        usage = self._calc_response_usage(
+            model, credentials, prompt_tokens, completion_tokens
+        )
         # transform response
         result = LLMResult(
@@ -215,8 +254,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         return result
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse,
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: GenerateContentResponse,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator:
         """
         Handle llm stream response
@@ -232,28 +276,29 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             index += 1
             assistant_prompt_message = AssistantPromptMessage(
-                content=content if content else '',
+                content=content if content else "",
             )
             if not response._done:
-
                 # transform assistant message to prompt message
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=assistant_prompt_message
-                    )
+                        index=index, message=assistant_prompt_message
+                    ),
                 )
             else:
-
                 # calculate num tokens
                 prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+                completion_tokens = self.get_num_tokens(
+                    model, credentials, [assistant_prompt_message]
+                )
                 # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                usage = self._calc_response_usage(
+                    model, credentials, prompt_tokens, completion_tokens
+                )
                 yield LLMResultChunk(
                     model=model,
@@ -262,8 +307,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                         index=index,
                         message=assistant_prompt_message,
                         finish_reason=chunk.candidates[0].finish_reason,
-                        usage=usage
-                    )
+                        usage=usage,
+                    ),
                 )
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
@@ -302,21 +347,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         parts = []
-        if (isinstance(message.content, str)):
+        if isinstance(message.content, str):
             parts.append(to_part(message.content))
         else:
             for c in message.content:
                 if c.type == PromptMessageContentType.TEXT:
                     parts.append(to_part(c.data))
                 else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                    metadata, data = c.data.split(",", 1)
+                    mime_type = metadata.split(";", 1)[0].split(":")[1]
+                    blob = {"inline_data": {"mime_type": mime_type, "data": data}}
                     parts.append(blob)
         glm_content = {
-            "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
-            "parts": parts
+            "role": "user"
+            if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM)
+            else "model",
+            "parts": parts,
         }
         return glm_content
@@ -332,25 +379,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
        """
        return {
-            InvokeConnectionError: [
-                exceptions.RetryError
-            ],
+            InvokeConnectionError: [exceptions.RetryError],
             InvokeServerUnavailableError: [
                 exceptions.ServiceUnavailable,
                 exceptions.InternalServerError,
                 exceptions.BadGateway,
                 exceptions.GatewayTimeout,
-                exceptions.DeadlineExceeded
+                exceptions.DeadlineExceeded,
             ],
             InvokeRateLimitError: [
                 exceptions.ResourceExhausted,
-                exceptions.TooManyRequests
+                exceptions.TooManyRequests,
             ],
             InvokeAuthorizationError: [
                 exceptions.Unauthenticated,
                 exceptions.PermissionDenied,
                 exceptions.Unauthenticated,
-                exceptions.Forbidden
+                exceptions.Forbidden,
             ],
             InvokeBadRequestError: [
                 exceptions.BadRequest,
@@ -366,5 +411,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             exceptions.PreconditionFailed,
             exceptions.RequestRangeNotSatisfiable,
             exceptions.Cancelled,
-        ]
+        ],
     }
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/groq/groq.py b/model-providers/model_providers/core/model_runtime/model_providers/groq/groq.py
index b4dca94b..c79e0c5c 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/groq/groq.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/groq/groq.py
@@ -1,13 +1,17 @@
 import logging
 from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.model_provider import (
+    ModelProvider,
+)
 logger = logging.getLogger(__name__)
-class GroqProvider(ModelProvider):
+class GroqProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
@@ -19,11 +23,12 @@ class GroqProvider(ModelProvider):
             model_instance = self.get_model_instance(ModelType.LLM)
             model_instance.validate_credentials(
-                model='llama2-70b-4096',
-                credentials=credentials
+                model="llama2-70b-4096", credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
             raise ex
         except Exception as ex:
-            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            logger.exception(
+                f"{self.get_provider_schema().provider} credentials validate failed"
+            )
             raise ex
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
index 73e3894f..58a76581 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
@@ -2,18 +2,31 @@
 from collections.abc import Generator
 from typing import Optional, Union
 from model_providers.core.model_runtime.entities.llm_entities import LLMResult
-from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+from model_providers.core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageTool,
+)
+from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
+    OAIAPICompatLargeLanguageModel,
+)
 class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
-        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
+        return super()._invoke(
+            model, credentials, prompt_messages, model_parameters, tools, stop, stream
+        )
     def validate_credentials(self, model: str, credentials: dict) -> None:
         self._add_custom_parameters(credentials)
@@ -21,6 +34,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     @staticmethod
     def _add_custom_parameters(credentials: dict) -> None:
-        credentials['mode'] = 'chat'
-        credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
-
+        credentials["mode"] = "chat"
+        credentials["endpoint_url"] = "https://api.groq.com/openai/v1"
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
index ecfa9008..e14f2653 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
@@ -1,15 +1,12 @@
 from huggingface_hub.utils import BadRequestError, HfHubHTTPError
-from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
+from model_providers.core.model_runtime.errors.invoke import (
+    InvokeBadRequestError,
+    InvokeError,
+)
 class _CommonHuggingfaceHub:
-
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
-        return {
-            InvokeBadRequestError: [
-                HfHubHTTPError,
-                BadRequestError
-            ]
-        }
+        return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py
index 027fc87f..1420e82a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py
@@ -1,11 +1,12 @@
 import logging
-from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from model_providers.core.model_runtime.model_providers.__base.model_provider import (
+    ModelProvider,
+)
 logger = logging.getLogger(__name__)
 class HuggingfaceHubProvider(ModelProvider):
-
     def validate_provider_credentials(self, credentials: dict) -> None:
         pass
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
index d427492a..d47f0461 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -7,7 +7,12 @@
 from huggingface_hub.utils import BadRequestError
 from model_providers.core.model_runtime.entities.common_entities import I18nObject
 from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
-from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMMode,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -23,22 +28,35 @@
 from model_providers.core.model_runtime.entities.model_entities import (
     ModelType,
     ParameterRule,
 )
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from model_providers.core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
+    LargeLanguageModel,
+)
+from model_providers.core.model_runtime.model_providers.huggingface_hub._common import (
+    _CommonHuggingfaceHub,
+)
 class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
-                user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
+        client = InferenceClient(token=credentials["huggingfacehub_api_token"])
-        client = InferenceClient(token=credentials['huggingfacehub_api_token'])
+        if credentials["huggingfacehub_api_type"] == "inference_endpoints":
+            model = credentials["huggingfacehub_endpoint_url"]
-        if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            model = credentials['huggingfacehub_endpoint_url']
-
-        if 'baichuan' in model.lower():
+        if "baichuan" in model.lower():
             stream = False
         response = client.text_generation(
@@ -47,71 +65,97 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
             stream=stream,
             model=model,
             stop_sequences=stop,
-            **model_parameters)
+            **model_parameters,
+        )
         if stream:
-            return self._handle_generate_stream_response(model, credentials, prompt_messages, response)
+            return self._handle_generate_stream_response(
+                model, credentials, prompt_messages, response
+            )
-        return self._handle_generate_response(model, credentials, prompt_messages, response)
+        return self._handle_generate_response(
+            model, credentials, prompt_messages, response
+        )
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
         prompt = self._convert_messages_to_prompt(prompt_messages)
         return self._get_num_tokens_by_gpt2(prompt)
     def validate_credentials(self, model: str, credentials: dict) -> None:
         try:
-            if 'huggingfacehub_api_type' not in credentials:
-                raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.')
+            if "huggingfacehub_api_type" not in credentials:
+                raise CredentialsValidateFailedError(
+                    "Huggingface Hub Endpoint Type must be provided."
+                )
-            if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'):
-                raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.')
+            if credentials["huggingfacehub_api_type"] not in (
+                "inference_endpoints",
+                "hosted_inference_api",
+            ):
+                raise CredentialsValidateFailedError(
+                    "Huggingface Hub Endpoint Type is invalid."
+                )
-            if 'huggingfacehub_api_token' not in credentials:
-                raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.')
+            if "huggingfacehub_api_token" not in credentials:
+                raise CredentialsValidateFailedError(
+                    "Huggingface Hub Access Token must be provided."
+                )
-            if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-                if 'huggingfacehub_endpoint_url' not in credentials:
-                    raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.')
+            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
+                if "huggingfacehub_endpoint_url" not in credentials:
+                    raise CredentialsValidateFailedError(
+                        "Huggingface Hub Endpoint URL must be provided."
+                    )
-                if 'task_type' not in credentials:
-                    raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.')
-            elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
-                credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'],
-                                                                             model)
+                if "task_type" not in credentials:
+                    raise CredentialsValidateFailedError(
+                        "Huggingface Hub Task Type must be provided."
+                    )
+            elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
+                credentials["task_type"] = self._get_hosted_model_task_type(
+                    credentials["huggingfacehub_api_token"], model
+                )
-            if credentials['task_type'] not in ("text2text-generation", "text-generation"):
-                raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, ' 'text-generation.')
+            if credentials["task_type"] not in (
+                "text2text-generation",
+                "text-generation",
+            ):
+                raise CredentialsValidateFailedError(
+                    "Huggingface Hub Task Type must be one of text2text-generation, "
+                    "text-generation."
+                )
-            client = InferenceClient(token=credentials['huggingfacehub_api_token'])
+            client = InferenceClient(token=credentials["huggingfacehub_api_token"])
-            if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-                model = credentials['huggingfacehub_endpoint_url']
+            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
+                model = credentials["huggingfacehub_endpoint_url"]
             try:
-                client.text_generation(
-                    prompt='Who are you?',
-                    stream=True,
-                    model=model)
+                client.text_generation(prompt="Who are you?", stream=True, model=model)
             except BadRequestError as e:
-                raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. ' 'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
+                raise CredentialsValidateFailedError(
+                    "Only available for models running on with the `text-generation-inference`. "
+                    "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
+ ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @@ -119,26 +163,27 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + DefaultParameterName.TEMPERATURE + ).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +193,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,42 +211,51 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 
1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, ) - return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] + return [ + temperature_rule, + top_k_rule, + top_p_rule, + max_new_tokens, + seed, + repetition_penalty, + ] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + response: Generator, + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,15 +264,17 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=model, @@ -240,20 +296,28 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + response: any, + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) result = LLMResult( model=model, @@ -270,15 +334,22 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if ( + "inference" in model_info.cardData + and not model_info.cardData["inference"] + ): + raise ValueError( + f"Inference API has been turned off for this model {model_name}." + ) valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError( + f"Model {model_name} is not a valid task, " + f"must be one of {valid_tasks}." 
+ ) except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -288,8 +359,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) return text.rstrip() diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 2b3297f8..a0451017 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -7,35 +7,51 @@ import requests from huggingface_hub import HfApi, InferenceClient from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelType, + PriceType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.huggingface_hub._common import ( + _CommonHuggingfaceHub, +) -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( json={ "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } + "options": {"wait_for_model": False, "use_cache": False}, }, - model=execute_model) + model=execute_model, + ) embeddings = json.loads(output.decode()) @@ -43,9 +59,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, 
TextEmbeddingModel usage = self._calc_response_usage(model, credentials, tokens) return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model + embeddings=self._mean_pooling(embeddings), usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -56,52 +70,64 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub Endpoint Type must be provided." + ) - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub API Token must be provided." + ) - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." + ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub Endpoint URL must be provided." + ) - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type must be provided." + ) - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type is invalid." + ) self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type( + credentials["huggingfacehub_api_token"], model + ) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError( + "Huggingface Hub Endpoint Type is invalid." 
+ ) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -118,34 +144,47 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel return embeddings # For example two: List[List[List[float]]], need to mean_pooling. - sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings] + sentence_embeddings = [ + np.mean(embedding[0], axis=0).tolist() for embedding in embeddings + ] return sentence_embeddings @staticmethod - def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str) -> None: + def _check_hosted_model_task_type( + huggingfacehub_api_token: str, model_name: str + ) -> None: hf_api = HfApi(token=huggingfacehub_api_token) model_info = hf_api.model_info(repo_id=model_name) try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if ( + "inference" in model_info.cardData + and not model_info.cardData["inference"] + ): + raise ValueError( + f"Inference API has been turned off for this model {model_name}." + ) valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError( + f"Model {model_name} is not a valid task, " + f"must be one of {valid_tasks}." 
+ ) except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: input_price_info = self.get_price( model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -156,7 +195,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +205,29 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if ( + item.get("status", {}).get("url") + == credentials["huggingfacehub_endpoint_url"] + ): model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." 
+ ) except Exception as e: raise ValueError(str(e)) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/jina.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/jina.py index 65cf6fc1..ea9e80d3 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/jina/jina.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/jina.py @@ -1,14 +1,17 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,11 +25,12 @@ class JinaProvider(ModelProvider): # Use `jina-embeddings-v2-base-en` model for validate, # no matter what model you pass in, text completion model or chat model model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials + model="jina-embeddings-v2-base-en", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py index 18137c69..09a3b2fa 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -2,7 +2,10 @@ from typing import Optional import httpx -from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from model_providers.core.model_runtime.entities.rerank_entities import ( + RerankDocument, + RerankResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.rerank_model import ( + RerankModel, +) class JinaRerankModel(RerankModel): @@ -20,9 +27,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. 
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -45,26 +59,29 @@ class JinaRerankModel(RerankModel): "model": model, "query": query, "documents": docs, - "top_n": top_n + "top_n": top_n, }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if ( + score_threshold is None + or result["relevance_score"] >= score_threshold + ): rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -75,7 +92,6 @@ class JinaRerankModel(RerankModel): :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -86,7 +102,7 @@ class JinaRerankModel(RerankModel): "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -99,7 +115,7 @@ class JinaRerankModel(RerankModel): return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73e..d80cbfa8 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ class JinaTokenizer: with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 6815c6a7..48d8fecb 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -5,7 +5,10 @@ from typing import Optional from requests import post from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -14,21 +17,37 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from 
model_providers.core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import ( + JinaTokenizer, +) class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - api_base: str = 'https://api.jina.ai/v1/embeddings' - models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1/embeddings" + models: list[str] = [ + "jina-embeddings-v2-base-en", + "jina-embeddings-v2-small-en", + "jina-embeddings-v2-base-zh", + "jina-embeddings-v2-base-de", + ] + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,31 +57,28 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' + "Authorization": "Bearer " + api_key, + "Content-Type": "application/json", } - data = { - 'model': model, - 'input': texts - } + data = {"model": model, "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +88,27 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=usage["total_tokens"] + ) result = TextEmbeddingResult( model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], + usage=usage, ) return result @@ -117,31 +137,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise 
CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -155,7 +167,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -166,7 +178,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py index 3c5545b9..0545b52d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py @@ -21,7 +21,12 @@ from openai.types.completion import Completion from yarl import URL from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -45,34 +50,60 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) from model_providers.core.model_runtime.utils import helper class LocalAILarguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + 
prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool] + ) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not supports """ + def tokens(text: str): """ - We cloud not determine which tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We cloud not determine which tokenizer to use, cause the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -85,10 +116,10 @@ class LocalAILarguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -124,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel): num_tokens += self._num_tokens_for_tools(tools) return num_tokens - + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for tool calling @@ -133,36 +164,37 @@ class LocalAILarguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += 
tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -177,141 +209,166 @@ class LocalAILarguageModel(LargeLanguageModel): :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[], stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: - raise ValueError(f"Unknown completion type {credentials['completion_type']}") - + raise ValueError( + f"Unknown completion type {credentials['completion_type']}" + ) + rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( + credentials.get("context_size", "2048") + ) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + 
stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ + extra_model_kwargs["functions"] = [ helper.dump_model(tool) for tool in tools ] - - if completion_type == 'chat_completion': + + if completion_type == "chat_completion": result = client.chat.completions.create( - messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + messages=[ + self._convert_prompt_message_to_dict(m) for m in prompt_messages + ], model=model_name, stream=stream, **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( - prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), + prompt=self._convert_prompt_message_to_completion_prompts( + prompt_messages + ), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + tools=tools, + prompt_messages=prompt_messages, ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + tools=tools, + prompt_messages=prompt_messages, ) - - if completion_type == 'completion': + + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + prompt_messages=prompt_messages, ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, + credentials=credentials, + response=result, + tools=tools, + prompt_messages=prompt_messages, ) - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -319,13 +376,13 @@ class LocalAILarguageModel(LargeLanguageModel): :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' - + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" + client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -346,41 +403,45 @@ class LocalAILarguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": 
message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str: + def _convert_prompt_message_to_completion_prompts( + self, messages: list[PromptMessage] + ) -> str: """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") - + return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -393,21 +454,27 @@ class LocalAILarguageModel(LargeLanguageModel): """ if len(response.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = response.choices[0].text # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] + content=assistant_message, tool_calls=[] ) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) - completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=[] + ) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) response = LLMResult( model=model, @@ -419,11 +486,14 @@ class LocalAILarguageModel(LargeLanguageModel): return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -436,23 +506,33 @@ class LocalAILarguageModel(LargeLanguageModel): """ if len(response.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = response.choices[0].message # convert function call to tool call function_calls = assistant_message.function_call - tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) + tool_calls = 
self._extract_response_tool_calls( + [function_calls] if function_calls else [] + ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls + content=assistant_message.content, tool_calls=tool_calls ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=tools + ) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) response = LLMResult( model=model, @@ -464,12 +544,15 @@ class LocalAILarguageModel(LargeLanguageModel): return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -479,26 +562,30 @@ class LocalAILarguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] + content=delta.text if delta.text else "", tool_calls=[] ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] + content=full_response, tool_calls=[] ) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -507,7 +594,7 @@ class LocalAILarguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -523,12 +610,15 @@ class LocalAILarguageModel(LargeLanguageModel): full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + 
tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -536,35 +626,46 @@ class LocalAILarguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and ( + delta.delta.content is None or delta.delta.content == "" + ): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls( + function_calls if function_calls else [] + ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=assistant_message_tool_calls, ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -573,7 +674,7 @@ class LocalAILarguageModel(LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -589,9 +690,9 @@ class LocalAILarguageModel(LargeLanguageModel): full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -602,18 +703,15 @@ class LocalAILarguageModel(LargeLanguageModel): if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function + id=0, type="function", function=function ) tool_calls.append(tool_call) - return tool_calls + return tool_calls @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -635,15 +733,9 @@ class LocalAILarguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - 
PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/localai/localai.py b/model-providers/model_providers/core/model_runtime/model_providers/localai/localai.py index c4ccd3bc..828944f4 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/localai/localai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/localai/localai.py @@ -1,11 +1,12 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 14d82034..b42cfb3c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -6,8 +6,17 @@ from requests import post from yarl import URL from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -16,17 +25,26 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) class LocalAITextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,39 +55,38 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - - url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("model_name is required") - data = { - 'model': model_name, - 'input': texts[0] - } + url = server_url + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} + + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post( + str(URL(url) / "embeddings"), + headers=headers, + data=dumps(data), + timeout=10, + ) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +96,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=usage["total_tokens"] + ) result = TextEmbeddingResult( model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], + usage=usage, ) return result @@ -114,8 +135,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - - def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + + def _get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: """ Get customizable model schema @@ -130,10 +153,12 @@ class 
LocalAITextEmbeddingModel(TextEmbeddingModel): features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int( + credentials.get("context_size", "512") + ), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,33 +170,25 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -185,7 +202,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -196,7 +213,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 279f9d3b..7f2edebc 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -12,47 +12,61 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors impor InvalidAuthenticationError, RateLimitReachedError, ) -from model_providers.core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage +from model_providers.core.model_runtime.model_providers.minimax.llm.types import ( + MinimaxMessage, +) class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( 
+ self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if ( + "max_tokens" in model_parameters + and type(model_parameters["max_tokens"]) == int + ): + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if ( + "temperature" in model_parameters + and type(model_parameters["temperature"]) == float + ): + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,40 +74,48 @@ class MinimaxChatCompletion: # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] + raise BadRequestError("At least one user message is required") + + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' + "Authorization": "Bearer " + api_key, + "Content-Type": "application/json", } body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + url=url, + data=dumps(body), + headers=headers, + stream=stream, + timeout=(10, 300), + ) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise 
InternalServerError(msg) @@ -110,65 +132,64 @@ class MinimaxChatCompletion: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - + message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value + content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value ) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message - def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: + def _handle_stream_chat_generate_response( + self, response: Response + ) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage( + role=MinimaxMessage.Role.ASSISTANT.value, content="" ) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - message.stop_reason = data['choices'][0]['finish_reason'] + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] + message = choice["delta"] yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value + content=message, role=MinimaxMessage.Role.ASSISTANT.value ) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index ee54baaa..cbe6979f 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -12,98 
+12,115 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors impor InvalidAuthenticationError, RateLimitReachedError, ) -from model_providers.core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage +from model_providers.core.model_runtime.model_providers.minimax.llm.types import ( + MinimaxMessage, +) class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement it, but the parameters are reserved + Minimax Chat Completion Pro API, supports function calling + however, we do not have enough time and energy to implement it, but the parameters are reserved """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if ( + "max_tokens" in model_parameters + and type(model_parameters["max_tokens"]) == int + ): + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if ( + "temperature" in model_parameters + and type(model_parameters["temperature"]) == float + ): + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] - - if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']: - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if ( + "plugin_web_search" in model_parameters + and model_parameters["plugin_web_search"] + ): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages 
= prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - + raise BadRequestError("At least one user message is required") + messages = [message.to_dict() for message in prompt_messages] headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' + "Authorization": "Bearer " + api_key, + "Content-Type": "application/json", } body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = { 'type': 'auto' } + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + url=url, + data=dumps(body), + headers=headers, + stream=stream, + timeout=(10, 300), + ) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001 or code == 1013 or code == 1027: raise InternalServerError(msg) @@ -120,92 +137,101 @@ class MinimaxChatCompletionPro: def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - + message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value + content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value ) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message - def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: + def _handle_stream_chat_generate_response( + self, response: Response + ) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ function_call_storage = None for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = 
data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply'] or 'usage' in data and data['usage']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' + if data["reply"] or "usage" in data and data["usage"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage( + role=MinimaxMessage.Role.ASSISTANT.value, content="" ) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - message.stop_reason = data['choices'][0]['finish_reason'] + message.stop_reason = data["choices"][0]["finish_reason"] if function_call_storage: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message = MinimaxMessage( + content="", role=MinimaxMessage.Role.ASSISTANT.value + ) function_call_message.function_call = function_call_storage yield function_call_message yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] - if 'function_call' in message: + if "function_call" in message: if not function_call_storage: - function_call_storage = message['function_call'] - if 'arguments' not in function_call_storage or not function_call_storage['arguments']: - function_call_storage['arguments'] = '' + function_call_storage = message["function_call"] + if ( + "arguments" not in function_call_storage + or not function_call_storage["arguments"] + ): + function_call_storage["arguments"] = "" continue else: - function_call_storage['arguments'] += message['function_call']['arguments'] + function_call_storage["arguments"] += message["function_call"][ + "arguments" + ] continue else: if function_call_storage: - message['function_call'] = function_call_storage + message["function_call"] = function_call_storage function_call_storage = None - - minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - if 'function_call' in message: - minimax_message.function_call = message['function_call'] + minimax_message = MinimaxMessage( + content="", role=MinimaxMessage.Role.ASSISTANT.value + ) - if 'text' in message: - minimax_message.content = message['text'] + if "function_call" in message: + minimax_message.function_call = message["function_call"] + + if "text" in message: + minimax_message.content = message["text"] yield minimax_message diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/errors.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6..309b5cf4 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py 
b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py index d827b49e..1696bc7a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py @@ -1,6 +1,10 @@ from collections.abc import Generator -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -17,10 +21,18 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion import MinimaxChatCompletion -from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion_pro import MinimaxChatCompletionPro +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion import ( + MinimaxChatCompletion, +) +from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion_pro import ( + MinimaxChatCompletionPro, +) from model_providers.core.model_runtime.model_providers.minimax.llm.errors import ( BadRequestError, InsufficientAccountBalanceError, @@ -29,131 +41,202 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors impor InvalidAuthenticationError, RateLimitReachedError, ) -from model_providers.core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage +from model_providers.core.model_runtime.model_providers.minimax.llm.types import ( + MinimaxMessage, +) class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials 
for Baichuan model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool] + ) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to caculate the num tokens, so we use str() to convert the prompt to string + not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide their own tokenizer of adab5.5 and abab5 model + therefore, we use gpt2 tokenizer instead """ - messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] + messages_dict = [ + self._convert_prompt_message_to_minimax_message(m).to_dict() + for m in messages + ] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface 
""" client: MinimaxChatCompletionPro = self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], - prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[ + self._convert_prompt_message_to_minimax_message(message) + for message in prompt_messages + ], model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, + prompt_messages=prompt_messages, + credentials=credentials, + response=response, + ) + return self._handle_chat_generate_response( + model=model, + prompt_messages=prompt_messages, + credentials=credentials, + response=response, + ) - def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: + def _convert_prompt_message_to_minimax_message( + self, prompt_message: PromptMessage + ) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): - return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) + return MinimaxMessage( + role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content + ) elif isinstance(prompt_message, UserPromptMessage): - return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) + return MinimaxMessage( + role=MinimaxMessage.Role.USER.value, content=prompt_message.content + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' + role=MinimaxMessage.Role.ASSISTANT.value, content="" ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message - return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) + return MinimaxMessage( + role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content + ) elif isinstance(prompt_message, ToolPromptMessage): - return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) + return MinimaxMessage( + role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError( + f"Prompt message type {type(prompt_message)} is not supported" + ) - 
def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: MinimaxMessage, + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -164,15 +247,20 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, @@ -180,15 +268,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if ( + "name" not in message.function_call + or "arguments" not in message.function_call + ): continue yield LLMResultChunk( @@ -197,15 +289,17 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], + arguments=message.function_call["arguments"], + ), ) - )] + ], ), ), ) @@ -216,10 +310,11 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) @@ -234,22 +329,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - 
InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9..5e9d73dd 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 +4,32 @@ from typing import Any class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call + "sender_type": "BOT", + "sender_name": "专家", + "text": "", + "function_call": self.function_call, } - + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/minimax.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/minimax.py index d85b2293..e0f0bd36 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/minimax.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/minimax.py @@ -1,11 +1,16 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -20,11 +25,12 @@ class MinimaxProvider(ModelProvider): # Use `abab5.5-chat` model for validate, model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials + model="abab5.5-chat", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception( + 
f"{self.get_provider_schema().provider} credentials validate failed" + ) + raise CredentialsValidateFailedError(f"{ex}") diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 9c065b37..ff764bb9 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -5,7 +5,10 @@ from typing import Optional from requests import post from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -14,8 +17,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) from model_providers.core.model_runtime.model_providers.minimax.llm.errors import ( BadRequestError, InsufficientAccountBalanceError, @@ -30,11 +37,16 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model. 
""" - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -44,55 +56,51 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' + "Authorization": "Bearer " + api_key, + "Content-Type": "application/json", } - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + data = {"model": "embo-01", "texts": texts, "type": "db"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=total_tokens ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) + return result def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -119,9 +127,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): if code == 1000 or code == 1001: @@ -148,26 +156,20 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - 
InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -181,7 +183,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -192,7 +194,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py index 224db4c4..364d4a92 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -2,24 +2,43 @@ from collections.abc import Generator from typing import Optional, Union from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral dose not support user/stop arguments stop = [] user = None - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) @@ -27,5 +46,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 
'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/mistralai.py b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/mistralai.py index 239556be..aefce058 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/mistralai.py @@ -1,14 +1,17 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -20,11 +23,12 @@ class MistralAIProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials + model="open-mistral-7b", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py index ed46dc30..bb3c4c91 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py @@ -6,11 +6,24 @@ from typing import Optional from pydantic import BaseModel from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider -from model_providers.core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from model_providers.core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator -from model_providers.core.utils.position_helper import get_position_map, sort_to_dict_by_position_map +from model_providers.core.model_runtime.entities.provider_entities import ( + ProviderConfig, + ProviderEntity, + SimpleProviderEntity, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) +from model_providers.core.model_runtime.schema_validators.model_credential_schema_validator import ( + ModelCredentialSchemaValidator, +) +from model_providers.core.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) +from model_providers.core.utils.position_helper import ( + 
get_position_map, + sort_to_dict_by_position_map, +) logger = logging.getLogger(__name__) @@ -91,8 +104,9 @@ class ModelProviderFactory: return filtered_credentials - def model_credentials_validate(self, provider: str, model_type: ModelType, - model: str, credentials: dict) -> dict: + def model_credentials_validate( + self, provider: str, model_type: ModelType, model: str, credentials: dict + ) -> dict: """ Validate model credentials @@ -123,11 +137,12 @@ class ModelProviderFactory: return filtered_credentials - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - provider_configs: Optional[list[ProviderConfig]] = None) \ - -> list[SimpleProviderEntity]: + def get_models( + self, + provider: Optional[str] = None, + model_type: Optional[ModelType] = None, + provider_configs: Optional[list[ProviderConfig]] = None, + ) -> list[SimpleProviderEntity]: """ Get all models for given model type @@ -142,7 +157,9 @@ class ModelProviderFactory: # convert provider_configs to dict provider_credentials_dict = {} for provider_config in provider_configs: - provider_credentials_dict[provider_config.provider] = provider_config.credentials + provider_credentials_dict[ + provider_config.provider + ] = provider_config.credentials # traverse all model_provider_extensions providers = [] @@ -192,7 +209,7 @@ class ModelProviderFactory: # get the provider extension model_provider_extension = model_provider_extensions.get(provider) if not model_provider_extension: - raise Exception(f'Invalid provider: {provider}') + raise Exception(f"Invalid provider: {provider}") # get the provider instance model_provider_instance = model_provider_extension.provider_instance @@ -203,7 +220,6 @@ class ModelProviderFactory: if self.model_provider_extensions: return self.model_provider_extensions - # get the path of current classes current_path = os.path.abspath(__file__) model_providers_path = os.path.dirname(current_path) @@ -212,8 +228,8 @@ class ModelProviderFactory: model_provider_dir_paths = [ os.path.join(model_providers_path, model_provider_dir) for model_provider_dir in os.listdir(model_providers_path) - if not model_provider_dir.startswith('__') - and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) + if not model_provider_dir.startswith("__") + and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) ] # get _position.yaml file path @@ -227,37 +243,54 @@ class ModelProviderFactory: file_names = os.listdir(model_provider_dir_path) - if (model_provider_name + '.py') not in file_names: - logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.") + if (model_provider_name + ".py") not in file_names: + logger.warning( + f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip." 
+ ) continue # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider - py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') - spec = importlib.util.spec_from_file_location(f'model_providers.core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path) + py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py") + spec = importlib.util.spec_from_file_location( + f"model_providers.core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}", + py_path, + ) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) model_provider_class = None for name, obj in vars(mod).items(): - if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider: + if ( + isinstance(obj, type) + and issubclass(obj, ModelProvider) + and obj != ModelProvider + ): model_provider_class = obj break if not model_provider_class: - logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") + logger.warning( + f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip." + ) continue - if f'{model_provider_name}.yaml' not in file_names: - logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") + if f"{model_provider_name}.yaml" not in file_names: + logger.warning( + f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip." + ) continue - model_providers.append(ModelProviderExtension( - name=model_provider_name, - provider_instance=model_provider_class(), - position=position_map.get(model_provider_name) - )) + model_providers.append( + ModelProviderExtension( + name=model_provider_name, + provider_instance=model_provider_class(), + position=position_map.get(model_provider_name), + ) + ) - sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) + sorted_extensions = sort_to_dict_by_position_map( + position_map, model_providers, lambda x: x.name + ) self.model_provider_extensions = sorted_extensions diff --git a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py index f4fae7f6..fbf4bf29 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -2,19 +2,39 @@ from collections.abc import Generator from typing import Optional, Union from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, 
+ credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) user = user[:32] if user else None - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) @@ -22,5 +42,5 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" diff --git a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/moonshot.py b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/moonshot.py index d8369bba..0e40fac0 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/moonshot.py @@ -1,14 +1,17 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -20,11 +23,12 @@ class MoonshotProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials + model="moonshot-v1-8k", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py index 8132afad..19a7cc58 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py @@ -8,7 +8,12 @@ from urllib.parse import urljoin import requests -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -39,8 +44,12 @@ from 
model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) logger = logging.getLogger(__name__) @@ -50,11 +59,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel): Model class for Ollama large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -75,11 +90,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -100,10 +120,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if isinstance(first_prompt_message.content, str): text = first_prompt_message.content else: - text = '' + text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -121,19 +143,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel): model=model, credentials=credentials, prompt_messages=[UserPromptMessage(content="ping")], - model_parameters={ - 'num_predict': 5 - }, - stream=False + model_parameters={"num_predict": 5}, + stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {ex.description}" + ) except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {str(ex)}" + ) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, 
+ user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -146,76 +177,89 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials['base_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["base_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'stream': stream - } + data = {"model": model, "stream": stream} - if 'format' in model_parameters: - data['format'] = model_parameters['format'] - del model_parameters['format'] + if "format" in model_parameters: + data["format"] = model_parameters["format"] + del model_parameters["format"] - data['options'] = model_parameters or {} + data["options"] = model_parameters or {} if stop: - data['stop'] = "\n".join(stop) + data["stop"] = "\n".join(stop) - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'api/chat') - data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "api/chat") + data["messages"] = [ + self._convert_prompt_message_to_dict(m) for m in prompt_messages + ] else: - endpoint_url = urljoin(endpoint_url, 'api/generate') + endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] if isinstance(first_prompt_message, UserPromptMessage): first_prompt_message = cast(UserPromptMessage, first_prompt_message) if isinstance(first_prompt_message.content, str): - data['prompt'] = first_prompt_message.content + data["prompt"] = first_prompt_message.content else: - text = '' + text = "" images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + message_content = cast( + ImagePromptMessageContent, message_content + ) + image_data = re.sub( + r"^data:image\/[a-zA-Z]+;base64,", + "", + message_content.data, + ) images.append(image_data) - data['prompt'] = text - data['images'] = images + data["prompt"] = text + data["images"] = images # send a post request to validate the credentials response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 60), - stream=stream + endpoint_url, headers=headers, json=data, timeout=(10, 60), stream=stream ) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") + raise InvokeError( + f"API request failed with status code {response.status_code}: {response.text}" + ) if stream: - return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, completion_type, response, prompt_messages + ) - 
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) + return self._handle_generate_response( + model, credentials, completion_type, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode, - response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + completion_type: LLMMode, + response: requests.Response, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm completion response @@ -229,14 +273,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel): response_json = response.json() if completion_type is LLMMode.CHAT: - message = response_json.get('message', {}) - response_content = message.get('content', '') + message = response_json.get("message", {}) + response_content = message.get("content", "") else: - response_content = response_json['response'] + response_content = response_json["response"] assistant_message = AssistantPromptMessage(content=response_content) - if 'prompt_eval_count' in response_json and 'eval_count' in response_json: + if "prompt_eval_count" in response_json and "eval_count" in response_json: # transform usage prompt_tokens = response_json["prompt_eval_count"] completion_tokens = response_json["eval_count"] @@ -246,7 +290,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -258,8 +304,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode, - response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + completion_type: LLMMode, + response: requests.Response, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm completion stream response @@ -270,17 +322,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) return LLMResultChunk( model=model, @@ -289,11 +344,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel): index=index, message=message, finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'): + for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"): if not chunk: continue @@ -304,7 +359,7 @@ class 
OllamaLargeLanguageModel(LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) chunk_index += 1 @@ -314,55 +369,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if not chunk_json: continue - if 'message' not in chunk_json: - text = '' + if "message" not in chunk_json: + text = "" else: - text = chunk_json.get('message').get('content', '') + text = chunk_json.get("message").get("content", "") else: if not chunk_json: continue # transform assistant message to prompt message - text = chunk_json['response'] + text = chunk_json["response"] - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text - if chunk_json['done']: + if chunk_json["done"]: # calculate num tokens - if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json: + if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json: # transform usage prompt_tokens = chunk_json["prompt_eval_count"] completion_tokens = chunk_json["eval_count"] else: # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content) + prompt_tokens = self._get_num_tokens_by_gpt2( + prompt_messages[0].content + ) completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( - model=chunk_json['model'], + model=chunk_json["model"], prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - finish_reason='stop', - usage=usage - ) + finish_reason="stop", + usage=usage, + ), ) else: yield LLMResultChunk( - model=chunk_json['model'], + model=chunk_json["model"], prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 @@ -376,15 +433,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} else: - text = '' + text = "" images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + message_content = cast( + ImagePromptMessageContent, message_content + ) + image_data = re.sub( + r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data + ) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -414,7 +477,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ Get customizable model schema. 
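(Editorial aside, not part of the patch.) The chat branch of `_generate` reformatted above reduces to a single POST against the Ollama server: the base_url from credentials gets a trailing slash, model_parameters (minus "format") travel under "options", and chat mode sends role/content dicts to api/chat. A minimal standalone sketch of that request shape, assuming a local server at http://localhost:11434 and a hypothetical "llama2" model (these values are illustrative, not taken from the diff):

    # Illustrative sketch of the request _generate builds in chat mode.
    from urllib.parse import urljoin

    import requests

    base_url = "http://localhost:11434"  # credentials["base_url"] (assumed value)
    endpoint_url = base_url if base_url.endswith("/") else base_url + "/"

    data = {
        "model": "llama2",           # assumed model name
        "stream": False,
        # model_parameters other than "format" end up under "options"
        "options": {"temperature": 0.7, "num_predict": 128},
        # chat mode posts prompt messages as role/content dicts to api/chat
        "messages": [{"role": "user", "content": "ping"}],
    }

    response = requests.post(
        urljoin(endpoint_url, "api/chat"),
        headers={"Content-Type": "application/json"},
        json=data,
        timeout=(10, 60),
    )
    print(response.json().get("message", {}).get("content", ""))

Completion mode differs only in the target and payload keys: it posts to api/generate with "prompt" (and base64 "images" for vision content), and any stop sequences are joined with "\n" under "stop", as shown in the hunk above.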
@@ -425,20 +490,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel): """ extras = {} - if 'vision_support' in credentials and credentials['vision_support'] == 'true': - extras['features'] = [ModelFeature.VISION] + if "vision_support" in credentials and credentials["vision_support"] == "true": + extras["features"] = [ModelFeature.VISION] entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.MODE: credentials.get('mode'), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.MODE: credentials.get("mode"), + ModelPropertyKey.CONTEXT_SIZE: int( + credentials.get("context_size", 4096) + ), }, parameter_rules=[ ParameterRule( @@ -446,161 +510,191 @@ class OllamaLargeLanguageModel(LargeLanguageModel): use_template=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - help=I18nObject(en_US="The temperature of the model. " - "Increasing the temperature will make the model answer " - "more creatively. (Default: 0.8)"), + help=I18nObject( + en_US="The temperature of the model. " + "Increasing the temperature will make the model answer " + "more creatively. (Default: 0.8)" + ), default=0.8, min=0, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, use_template=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " - "more diverse text, while a lower value (e.g., 0.5) will generate more " - "focused and conservative text. (Default: 0.9)"), + help=I18nObject( + en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " + "more diverse text, while a lower value (e.g., 0.5) will generate more " + "focused and conservative text. (Default: 0.9)" + ), default=0.9, min=0, - max=1 + max=1, ), ParameterRule( name="top_k", label=I18nObject(en_US="Top K"), type=ParameterType.INT, - help=I18nObject(en_US="Reduces the probability of generating nonsense. " - "A higher value (e.g. 100) will give more diverse answers, " - "while a lower value (e.g. 10) will be more conservative. (Default: 40)"), + help=I18nObject( + en_US="Reduces the probability of generating nonsense. " + "A higher value (e.g. 100) will give more diverse answers, " + "while a lower value (e.g. 10) will be more conservative. (Default: 40)" + ), default=40, min=1, - max=100 + max=100, ), ParameterRule( - name='repeat_penalty', + name="repeat_penalty", label=I18nObject(en_US="Repeat Penalty"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Sets how strongly to penalize repetitions. " - "A higher value (e.g., 1.5) will penalize repetitions more strongly, " - "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"), + help=I18nObject( + en_US="Sets how strongly to penalize repetitions. " + "A higher value (e.g., 1.5) will penalize repetitions more strongly, " + "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" + ), default=1.1, min=-2, - max=2 + max=2, ), ParameterRule( - name='num_predict', - use_template='max_tokens', + name="num_predict", + use_template="max_tokens", label=I18nObject(en_US="Num Predict"), type=ParameterType.INT, - help=I18nObject(en_US="Maximum number of tokens to predict when generating text. 
" - "(Default: 128, -1 = infinite generation, -2 = fill context)"), + help=I18nObject( + en_US="Maximum number of tokens to predict when generating text. " + "(Default: 128, -1 = infinite generation, -2 = fill context)" + ), default=128, min=-2, - max=int(credentials.get('max_tokens', 4096)), + max=int(credentials.get("max_tokens", 4096)), ), ParameterRule( - name='mirostat', + name="mirostat", label=I18nObject(en_US="Mirostat sampling"), type=ParameterType.INT, - help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. " - "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"), + help=I18nObject( + en_US="Enable Mirostat sampling for controlling perplexity. " + "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" + ), default=0, min=0, - max=2 + max=2, ), ParameterRule( - name='mirostat_eta', + name="mirostat_eta", label=I18nObject(en_US="Mirostat Eta"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from " - "the generated text. A lower learning rate will result in slower adjustments, " - "while a higher learning rate will make the algorithm more responsive. " - "(Default: 0.1)"), + help=I18nObject( + en_US="Influences how quickly the algorithm responds to feedback from " + "the generated text. A lower learning rate will result in slower adjustments, " + "while a higher learning rate will make the algorithm more responsive. " + "(Default: 0.1)" + ), default=0.1, - precision=1 + precision=1, ), ParameterRule( - name='mirostat_tau', + name="mirostat_tau", label=I18nObject(en_US="Mirostat Tau"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. " - "A lower value will result in more focused and coherent text. (Default: 5.0)"), + help=I18nObject( + en_US="Controls the balance between coherence and diversity of the output. " + "A lower value will result in more focused and coherent text. (Default: 5.0)" + ), default=5.0, - precision=1 + precision=1, ), ParameterRule( - name='num_ctx', + name="num_ctx", label=I18nObject(en_US="Size of context window"), type=ParameterType.INT, - help=I18nObject(en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)"), + help=I18nObject( + en_US="Sets the size of the context window used to generate the next token. " + "(Default: 2048)" + ), default=2048, - min=1 - ), - ParameterRule( - name='num_gpu', - label=I18nObject(en_US="Num GPU"), - type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to send to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable."), - default=1, - min=0, - max=1 - ), - ParameterRule( - name='num_thread', - label=I18nObject(en_US="Num Thread"), - type=ParameterType.INT, - help=I18nObject(en_US="Sets the number of threads to use during computation. " - "By default, Ollama will detect this for optimal performance. " - "It is recommended to set this value to the number of physical CPU cores " - "your system has (as opposed to the logical number of cores)."), min=1, ), ParameterRule( - name='repeat_last_n', + name="num_gpu", + label=I18nObject(en_US="Num GPU"), + type=ParameterType.INT, + help=I18nObject( + en_US="The number of layers to send to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." 
+ ), + default=1, + min=0, + max=1, + ), + ParameterRule( + name="num_thread", + label=I18nObject(en_US="Num Thread"), + type=ParameterType.INT, + help=I18nObject( + en_US="Sets the number of threads to use during computation. " + "By default, Ollama will detect this for optimal performance. " + "It is recommended to set this value to the number of physical CPU cores " + "your system has (as opposed to the logical number of cores)." + ), + min=1, + ), + ParameterRule( + name="repeat_last_n", label=I18nObject(en_US="Repeat last N"), type=ParameterType.INT, - help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. " - "(Default: 64, 0 = disabled, -1 = num_ctx)"), + help=I18nObject( + en_US="Sets how far back for the model to look back to prevent repetition. " + "(Default: 64, 0 = disabled, -1 = num_ctx)" + ), default=64, - min=-1 + min=-1, ), ParameterRule( - name='tfs_z', + name="tfs_z", label=I18nObject(en_US="TFS Z"), type=ParameterType.FLOAT, - help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens " - "from the output. A higher value (e.g., 2.0) will reduce the impact more, " - "while a value of 1.0 disables this setting. (default: 1)"), + help=I18nObject( + en_US="Tail free sampling is used to reduce the impact of less probable tokens " + "from the output. A higher value (e.g., 2.0) will reduce the impact more, " + "while a value of 1.0 disables this setting. (default: 1)" + ), default=1, - precision=1 + precision=1, ), ParameterRule( - name='seed', + name="seed", label=I18nObject(en_US="Seed"), type=ParameterType.INT, - help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to " - "a specific number will make the model generate the same text for " - "the same prompt. (Default: 0)"), - default=0 + help=I18nObject( + en_US="Sets the random number seed to use for generation. Setting this to " + "a specific number will make the model generate the same text for " + "the same prompt. (Default: 0)" + ), + default=0, ), ParameterRule( - name='format', + name="format", label=I18nObject(en_US="Format"), type=ParameterType.STRING, - help=I18nObject(en_US="the format to return a response in." - " Currently the only accepted value is json."), - options=['json'], - ) + help=I18nObject( + en_US="the format to return a response in." + " Currently the only accepted value is json." 
+ ), + options=["json"], + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), - **extras + **extras, ) return entity @@ -628,10 +722,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py index 9701248e..c6a78011 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/ollama.py @@ -1,12 +1,13 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 3376a081..fd037190 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -17,7 +17,10 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceConfig, PriceType, ) -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -26,8 +29,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) logger = logging.getLogger(__name__) @@ -37,9 +44,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -51,15 +62,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel): """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embeddings') + endpoint_url = urljoin(endpoint_url, "api/embeddings") # get model properties context_size = self._get_context_size(model, credentials) @@ -74,7 +83,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(len(text) * (np.floor(context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) @@ -83,8 +92,8 @@ class OllamaEmbeddingModel(TextEmbeddingModel): for text in inputs: # Prepare the payload for the request payload = { - 'prompt': text, - 'model': model, + "prompt": text, + "model": model, } # Make the request to the OpenAI API @@ -92,14 +101,14 @@ class OllamaEmbeddingModel(TextEmbeddingModel): endpoint_url, headers=headers, data=json.dumps(payload), - timeout=(10, 300) + timeout=(10, 300), ) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings = response_data['embedding'] + embeddings = response_data["embedding"] embedding_used_tokens = self.get_num_tokens(model, credentials, [text]) used_tokens += embedding_used_tokens @@ -107,15 +116,11 @@ class OllamaEmbeddingModel(TextEmbeddingModel): # calc usage usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens + model=model, credentials=credentials, tokens=used_tokens ) return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model + embeddings=batched_embeddings, usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -138,19 +143,21 @@ class OllamaEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {ex.description}" + ) except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {str(ex)}" + ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -158,20 +165,22 @@ 
class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -185,7 +194,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -196,7 +205,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -224,10 +233,10 @@ class OllamaEmbeddingModel(TextEmbeddingModel): ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py index 81676b07..c459c20b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py @@ -20,17 +20,17 @@ class _CommonOpenAI: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if 'openai_api_base' in credentials and credentials['openai_api_base']: - credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/') - credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1' + if "openai_api_base" in credentials and credentials["openai_api_base"]: + credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/") + credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs @@ -45,24 +45,17 @@ class _CommonOpenAI: :return: Invoke error mapping """ return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + 
InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], InvokeAuthorizationError: [ openai.AuthenticationError, - openai.PermissionDeniedError + openai.PermissionDeniedError, ], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py index 4fd03630..0b886642 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py @@ -5,12 +5,24 @@ from typing import Optional, Union, cast import tiktoken from openai import OpenAI, Stream from openai.types import Completion -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, +) from openai.types.chat.chat_completion_message import FunctionCall from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -22,10 +34,22 @@ from model_providers.core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelType, + PriceConfig, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.openai._common import ( + _CommonOpenAI, +) logger = logging.getLogger(__name__) @@ -38,16 +62,23 @@ if you are not sure about the structure. """ + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -63,8 +94,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -79,7 +110,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -90,26 +121,36 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters[ + "response_format" + ] in ["JSON", "XML"]: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -122,7 +163,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -134,9 +175,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -146,14 +187,21 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], 
model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -163,28 +211,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): stop.append("\n```") # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[0], SystemPromptMessage + ): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", prompt_messages[0].content + ).replace("{{block}}", response_format) + ) + prompt_messages.append( + AssistantPromptMessage(content=f"\n```{response_format}\n") ) - prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) - prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", + f"Please output a valid {response_format} object.", + ).replace("{{block}}", response_format) + ), + ) + prompt_messages.append( + AssistantPromptMessage(content=f"\n```{response_format}") + ) + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -201,25 +266,30 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", user_message.content + ).replace("{{block}}", response_format) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", 
user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", user_message.content + ).replace("{{block}}", response_format) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -230,8 +300,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -261,14 +331,16 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError( + f"Fine-tuned model {model} not found" + ) # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -276,7 +348,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -285,7 +357,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -312,14 +384,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [ + model for model in remote_models if model.id.startswith("ft:") + ] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None - for predefined_model_name, predefined_model in predefined_models_map.items(): + for ( + predefined_model_name, + predefined_model, + ) in predefined_models_map.items(): if predefined_model_name in base_model: base_model_schema = predefined_model @@ -328,30 +405,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + input=0.003, output=0.006, unit=0.001, currency="USD" + ), ) 
ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -373,10 +451,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( @@ -384,16 +462,25 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm completion response @@ -406,9 +493,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -417,11 +502,15 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): completion_tokens = response.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + prompt_tokens = self._num_tokens_from_string( + model, prompt_messages[0].content + ) completion_tokens = self._num_tokens_from_string(model, assistant_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -434,8 +523,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[Completion], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm completion stream response @@ -445,21 +539,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" 
for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text if delta.text else "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -471,11 +563,15 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): completion_tokens = chunk.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + prompt_tokens = self._num_tokens_from_string( + model, prompt_messages[0].content + ) completion_tokens = self._num_tokens_from_string(model, full_text) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=chunk.model, @@ -485,8 +581,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -496,13 +592,20 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -535,17 +638,20 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # chat model response = client.chat.completions.create( @@ -557,13 +663,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ) if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) + return self._handle_chat_generate_stream_response( + model, credentials, response, prompt_messages, tools + ) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + return self._handle_chat_generate_response( + model, credentials, response, prompt_messages, tools + ) - def _handle_chat_generate_response(self, model: str, 
credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -580,13 +695,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # extract tool calls from response # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - function_call = self._extract_response_function_call(assistant_message_function_call) + function_call = self._extract_response_function_call( + assistant_message_function_call + ) tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls + content=assistant_message.content, tool_calls=tool_calls ) # calculate num tokens @@ -596,11 +712,17 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): completion_tokens = response.usage.completion_tokens else: # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) - completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message]) + prompt_tokens = self._num_tokens_from_messages( + model, prompt_messages, tools + ) + completion_tokens = self._num_tokens_from_messages( + model, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response response = LLMResult( @@ -613,9 +735,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -625,7 +752,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None for chunk in response: if len(chunk.choices) == 0: @@ -634,8 +761,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -646,43 +776,56 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # handle process of stream function call if assistant_message_function_call: # message has not ended ever - 
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + delta_assistant_message_function_call_storage.arguments += ( + assistant_message_function_call.arguments + ) continue else: # message has ended - assistant_message_function_call = delta_assistant_message_function_call_storage + assistant_message_function_call = ( + delta_assistant_message_function_call_storage + ) delta_assistant_message_function_call_storage = None else: if assistant_message_function_call: # start of stream function call - delta_assistant_message_function_call_storage = assistant_message_function_call + delta_assistant_message_function_call_storage = ( + assistant_message_function_call + ) if not has_finish_reason: continue # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - function_call = self._extract_response_function_call(assistant_message_function_call) + function_call = self._extract_response_function_call( + assistant_message_function_call + ) tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=tool_calls, ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if has_finish_reason: # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) + prompt_tokens = self._num_tokens_from_messages( + model, prompt_messages, tools + ) full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=tool_calls + content=full_assistant_content, tool_calls=tool_calls + ) + completion_tokens = self._num_tokens_from_messages( + model, [full_assistant_prompt_message] ) - completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=chunk.model, @@ -692,8 +835,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -703,12 +846,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, + response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall], + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -720,20 +864,21 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + arguments=response_tool_call.function.arguments, ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call.id, 
type=response_tool_call.type, - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -744,13 +889,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_function_call.name, - arguments=response_function_call.arguments + arguments=response_function_call.arguments, ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -767,20 +910,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + message_content = cast( + TextPromptMessageContent, message_content + ) sub_message_dict = { "type": "text", - "text": message_content.data + "text": message_content.data, } sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + message_content = cast( + ImagePromptMessageContent, message_content + ) sub_message_dict = { "type": "image_url", "image_url": { "url": message_content.data, - "detail": message_content.detail.value - } + "detail": message_content.detail.value, + }, } sub_messages.append(sub_message_dict) @@ -809,7 +956,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): message_dict = { "role": "function", "content": message.content, - "name": message.tool_call_id + "name": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -819,8 +966,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -841,14 +989,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] try: encoding = tiktoken.encoding_for_model(model) @@ -883,10 +1035,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -915,7 +1067,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_for_tools(self, encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: + def _num_tokens_for_tools( + self, encoding: tiktoken.Encoding, tools: list[PromptMessageTool] + ) -> int: """ Calculate num tokens for tool calling with tiktoken package. @@ -925,64 +1079,66 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) return num_tokens - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. 
This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -991,18 +1147,16 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=[feature for feature in base_model_schema_features], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - key: property for key, property in base_model_schema_model_properties.items() + key: property + for key, property in base_model_schema_model_properties.items() }, parameter_rules=[rule for rule in base_model_schema_parameters_rules], - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py index 7301399c..8ff1958a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -4,9 +4,15 @@ from openai import OpenAI from openai.types import ModerationCreateResponse from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.moderation_model import ( + ModerationModel, +) +from model_providers.core.model_runtime.model_providers.openai._common import ( + _CommonOpenAI, +) class OpenAIModerationModel(_CommonOpenAI, ModerationModel): @@ -14,9 +20,9 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model. 
""" - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke( + self, model: str, credentials: dict, text: str, user: Optional[str] = None + ) -> bool: """ Invoke moderation model @@ -34,13 +40,18 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [ + text_chunks[i : i + max_text_chunks] + for i in range(0, len(text_chunks), max_text_chunks) + ] for text_chunk in chunks: - moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) + moderation_result = self._moderation_invoke( + model=model, client=client, texts=text_chunk + ) for result in moderation_result.results: if result.flagged is True: @@ -65,12 +76,14 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _moderation_invoke(self, model: str, client: OpenAI, texts: list[str]) -> ModerationCreateResponse: + def _moderation_invoke( + self, model: str, client: OpenAI, texts: list[str] + ) -> ModerationCreateResponse: """ Invoke moderation model @@ -94,8 +107,14 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + if ( + model_schema + and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK + in model_schema.model_properties + ): + return model_schema.model_properties[ + ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK + ] return 2000 @@ -109,7 +128,10 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: + if ( + model_schema + and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + ): return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] return 1 diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/openai.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/openai.py index a6fe87e4..e706bad5 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/openai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/openai.py @@ -1,14 +1,17 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class OpenAIProvider(ModelProvider): - def 
validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,11 +25,12 @@ class OpenAIProvider(ModelProvider): # Use `gpt-3.5-turbo` model for validate, # no matter what model you pass in, text completion model or chat model model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials + model="gpt-3.5-turbo", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/speech2text/speech2text.py index 0570d55f..dfd47595 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -2,9 +2,15 @@ from typing import IO, Optional from openai import OpenAI -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.speech2text_model import ( + Speech2TextModel, +) +from model_providers.core.model_runtime.model_providers.openai._common import ( + _CommonOpenAI, +) class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): @@ -12,9 +18,9 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke( + self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None + ) -> str: """ Invoke speech2text model @@ -37,12 +43,14 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: + def _speech2text_invoke( + self, model: str, credentials: dict, file: IO[bytes] + ) -> str: """ Invoke speech2text model diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 14f95eed..ef04e13f 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -7,10 +7,19 @@ import tiktoken from openai import OpenAI from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.openai._common import ( + _CommonOpenAI, +) class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): @@ -18,9 +27,13 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,9 +50,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +69,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -71,8 +82,8 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + texts=tokens[i : i + max_chunks], + extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens @@ -91,7 +102,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): model=model, client=client, texts="", - extra_model_kwargs=extra_model_kwargs + extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens @@ -102,16 +113,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): # calc usage usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens + model=model, credentials=credentials, tokens=used_tokens ) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,16 +158,18 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): # call embedding model self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} + model=model, client=client, texts=["ping"], extra_model_kwargs={} ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, + model: str, + client: OpenAI, + texts: Union[list[str], str], + extra_model_kwargs: dict, + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,14 +186,26 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if ( + "encoding_format" in extra_model_kwargs + and extra_model_kwargs["encoding_format"] == "base64" + ): # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [ + list( + np.frombuffer(base64.b64decode(data.embedding), dtype="float32") + ) + for data in response.data + ], + response.usage.total_tokens, + ) return [data.embedding 
for data in response.data], response.usage.total_tokens - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -200,7 +219,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -211,7 +230,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py index c44b6ca7..49267499 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/tts/tts.py @@ -3,13 +3,18 @@ from functools import reduce from io import BytesIO from typing import Optional +from fastapi.responses import StreamingResponse from openai import OpenAI from pydub import AudioSegment -from fastapi.responses import StreamingResponse + from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel -from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI +from model_providers.core.model_runtime.model_providers.openai._common import ( + _CommonOpenAI, +) from model_providers.extensions.ext_storage import storage @@ -18,8 +23,16 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ) -> any: """ _invoke text2speech model @@ -33,18 +46,33 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] + for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) if streaming: - return StreamingResponse(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice), media_type='text/event-stream') + return StreamingResponse( + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice, + ), + media_type="text/event-stream", + ) else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + return self._tts_invoke( + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) - def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + def validate_credentials( + self, model: str, credentials: dict, user: Optional[str] = None + ) -> None: """ validate credentials text2speech model @@ -57,13 +85,15 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): self._tts_invoke( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse: + def _tts_invoke( + self, model: str, credentials: dict, content_text: str, voice: str + ) -> StreamingResponse: """ _tts_invoke text2speech model @@ -77,13 +107,25 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) audio_bytes_list = list() # Create a thread pool and map the function to the list of sentences - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice, - credentials=credentials) for sentence in sentences] + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + futures = [ + executor.submit( + self._process_sentence, + sentence=sentence, + model=model, + voice=voice, + credentials=credentials, + ) + for sentence in sentences + ] for future in futures: try: if future.result(): @@ -92,8 +134,11 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): raise InvokeBadRequestError(str(ex)) if len(audio_bytes_list) > 0: - audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), 
format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] + audio_segments = [ + AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) + for audio_bytes in audio_bytes_list + if audio_bytes + ] combined_segment = reduce(lambda x, y: x + y, audio_segments) buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) @@ -103,8 +148,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): raise InvokeBadRequestError(str(ex)) # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + ) -> any: """ _tts_invoke_streaming text2speech model @@ -117,24 +168,29 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): """ # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + if not voice or voice not in self.get_tts_model_voices( + model=model, credentials=credentials + ): voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' + file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}" try: client = OpenAI(**credentials_kwargs) - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) for sentence in sentences: - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + response = client.audio.speech.create( + model=model, voice=voice, input=sentence.strip() + ) # response.stream_to_file(file_path) storage.save(file_path, response.read()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api @@ -147,6 +203,8 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + response = client.audio.speech.create( + model=model, voice=voice, input=sentence.strip() + ) if isinstance(response.read(), bytes): return response.read() diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py index 81bb3267..d5c3d879 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from model_providers.core.model_runtime.errors.invoke import ( @@ -35,10 +34,10 @@ class _CommonOAI_API_Compat: ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine 
Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e33c185d..05a2c7c9 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -8,7 +8,12 @@ from urllib.parse import urljoin import requests from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -33,9 +38,15 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceConfig, ) from model_providers.core.model_runtime.errors.invoke import InvokeError -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import ( + _CommonOAI_API_Compat, +) from model_providers.core.model_runtime.utils import helper logger = logging.getLogger(__name__) @@ -46,11 +57,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): Model class for OpenAI large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -74,11 +91,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -99,78 +121,80 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 60) + endpoint_url, headers=headers, json=data, timeout=(10, 60) ) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError( + "Credentials validation failed: JSON decode error" + ) - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result + or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 
'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result + or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError( + f"An error occurred during credentials validation: {str(ex)}" + ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ support_function_call = False features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'function_call': + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "function_call": features = [ModelFeature.TOOL_CALL] support_function_call = True endpoint_url = credentials["endpoint_url"] @@ -185,43 +209,45 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features if support_function_call else [], model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int( + credentials.get("context_size", "4096") + ), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -229,31 +255,40 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), 
], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: - raise ValueError(f"Unknown completion type {credentials['completion_type']}") + raise ValueError( + f"Unknown completion type {credentials['completion_type']}" + ) return entity # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -267,50 +302,53 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [ + self._convert_prompt_message_to_dict(m) for m in prompt_messages + ] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. 
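For orientation, the request body assembled here for a chat-mode, OpenAI-compatible endpoint can be sketched as plain dictionaries. The endpoint URL, API key, and the get_weather tool below are placeholders invented for illustration; only the field names (model, stream, messages, functions, and the tool_choice/tools alternative) come from the code shown in this diff.

import requests

# Placeholder values for illustration; _generate reads the real ones from the
# provider credentials and model_parameters dicts.
endpoint_url = "http://localhost:8000/v1/"
api_key = "sk-placeholder"

data = {
    "model": "gpt-3.5-turbo",
    "stream": False,
    "temperature": 0.7,
    "top_p": 1.0,
    # chat mode: prompt messages become OpenAI-style role/content dicts
    "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
    # function_calling_type == "function_call" uses the legacy `functions` field;
    # "tool_call" would instead set `tool_choice` to "auto" and fill a `tools` list.
    "functions": [
        {
            "name": "get_weather",  # hypothetical tool, not defined in this patch
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
}

response = requests.post(
    endpoint_url + "chat/completions",
    headers={
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    },
    json=data,
    timeout=(10, 60),
)
print(response.json()["choices"][0]["message"]["content"])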
- function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: - formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool))) + formatted_tools.append( + helper.dump_model(PromptMessageFunction(function=tool)) + ) data["tools"] = formatted_tools @@ -321,26 +359,33 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): data["user"] = user response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 60), - stream=stream + endpoint_url, headers=headers, json=data, timeout=(10, 60), stream=stream ) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") + raise InvokeError( + f"API request failed with status code {response.status_code}: {response.text}" + ) if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, response, prompt_messages + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: requests.Response, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -350,17 +395,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + prompt_tokens = self._num_tokens_from_string( + model, prompt_messages[0].content + ) + completion_tokens = self._num_tokens_from_string( + model, full_assistant_content + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) return LLMResultChunk( model=model, @@ -369,21 +421,22 @@ class 
OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): index=index, message=message, finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -392,45 +445,49 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') - if delta_content is None or delta_content == '': + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") + if delta_content is None or delta_content == "": continue - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response if assistant_message_tool_calls: - tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + tool_calls = self._extract_response_tool_calls( + assistant_message_tool_calls + ) # function_call = self._extract_response_function_call(assistant_message_function_call) # tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + tool_calls=tool_calls if assistant_message_tool_calls else [], ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=choice_text) + assistant_prompt_message = AssistantPromptMessage( + content=choice_text + ) full_assistant_content += choice_text else: continue @@ -440,7 +497,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): yield create_final_llm_result_chunk( index=chunk_index, message=assistant_prompt_message, - finish_reason=finish_reason + finish_reason=finish_reason, ) else: yield LLMResultChunk( @@ -449,40 +506,50 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: 
list[PromptMessage]) -> LLMResult: - + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: requests.Response, + prompt_messages: list[PromptMessage], + ) -> LLMResult: response_json = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] - assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) + assistant_message = AssistantPromptMessage( + content=response_content, tool_calls=[] + ) if tool_calls: - if function_calling_type == 'tool_call': - assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': - assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] + if function_calling_type == "tool_call": + assistant_message.tool_calls = self._extract_response_tool_calls( + tool_calls + ) + elif function_calling_type == "function_call": + assistant_message.tool_calls = [ + self._extract_response_function_call(tool_calls) + ] usage = response_json.get("usage") if usage: @@ -491,11 +558,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): completion_tokens = usage["completion_tokens"] else: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, assistant_message.content) + prompt_tokens = self._num_tokens_from_string( + model, prompt_messages[0].content + ) + completion_tokens = self._num_tokens_from_string( + model, assistant_message.content + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -522,17 +595,19 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message_content = cast(PromptMessageContent, message_content) sub_message_dict = { "type": "text", - "text": message_content.data + "text": message_content.data, } sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + message_content = cast( + ImagePromptMessageContent, message_content + ) sub_message_dict = { "type": "image_url", "image_url": { "url": message_content.data, - "detail": message_content.detail.value - } + "detail": 
message_content.detail.value, + }, } sub_messages.append(sub_message_dict) @@ -563,7 +638,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): message_dict = { "role": "function", "content": message.content, - "name": message.tool_call_id + "name": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -573,8 +648,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -590,8 +666,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. """ @@ -610,10 +690,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -651,46 +731,46 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += 
self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[dict] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -702,20 +782,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call["function"]["name"], - arguments=response_tool_call["function"]["arguments"] + arguments=response_tool_call["function"]["arguments"], ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"], type=response_tool_call["type"], - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -725,14 +806,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call['name'], - arguments=response_function_call['arguments'] + name=response_function_call["name"], + arguments=response_function_call["arguments"], ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call['name'], - type="function", - function=function + id=response_function_call["name"], type="function", function=function ) return tool_call diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 84e48463..645d7010 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -1,11 +1,12 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py 
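The file whose diff follows implements text embeddings against the same OpenAI-compatible surface. As a rough usage sketch (the endpoint and model name are placeholders, and a Bearer Authorization header would be added when an api_key is configured), a single batch request looks like this:

import requests

# Placeholder endpoint/model; the real model batches its inputs by max_chunks and
# truncates texts that exceed the configured context_size before sending.
payload = {
    "input": ["first text to embed", "second text to embed"],
    "model": "text-embedding-ada-002",
    "encoding_format": "float",
}

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=(10, 300),
)
resp.raise_for_status()
body = resp.json()

embeddings = [item["embedding"] for item in body["data"]]
used_tokens = body["usage"]["total_tokens"]
print(len(embeddings), used_tokens)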
index 78a0846e..0c71cfa0 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -16,10 +16,19 @@ from model_providers.core.model_runtime.entities.model_entities import ( PriceConfig, PriceType, ) -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import ( + _CommonOAI_API_Compat, +) class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): @@ -27,9 +36,13 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): Model class for an OpenAI API-compatible text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,27 +52,25 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - - # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } - api_key = credentials.get('api_key') + # Prepare headers and payload for the request + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +81,6 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +88,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): if num_tokens >= context_size: cutoff = int(len(text) * (np.floor(context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: 
cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -89,9 +99,9 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): for i in _iter: # Prepare the payload for the request payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs + "input": inputs[i : i + max_chunks], + "model": model, + **extra_model_kwargs, } # Make the request to the OpenAI API @@ -99,30 +109,26 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): endpoint_url, headers=headers, data=json.dumps(payload), - timeout=(10, 300) + timeout=(10, 300), ) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens + model=model, credentials=credentials, tokens=used_tokens ) - + return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model + embeddings=batched_embeddings, usage=usage, model=model ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -145,53 +151,54 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} response = requests.post( url=endpoint_url, headers=headers, data=json.dumps(payload), - timeout=(10, 300) + timeout=(10, 300), ) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') - - if 'model' not in json_result: raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + "Credentials validation failed: JSON decode error" + ) + + if "model" not in json_result: + raise CredentialsValidateFailedError( + "Credentials validation failed: invalid response" + ) except CredentialsValidateFailedError: raise except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from 
credentials """ entity = AIModelEntity( model=model, @@ -199,21 +206,22 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -227,7 +235,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -238,7 +246,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py index 84177585..271eca7e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,7 +1,12 @@ from collections.abc import Generator from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -24,9 +29,16 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.openllm.llm.openllm_generate import OpenLLMGenerate, OpenLLMGenerateMessage +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.openllm.llm.openllm_generate import ( + OpenLLMGenerate, + OpenLLMGenerateMessage, +) from model_providers.core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import ( BadRequestError, InsufficientAccountBalanceError, @@ -38,88 +50,149 @@ from model_providers.core.model_runtime.model_providers.openllm.llm.openllm_gene class 
OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, + server_url=credentials["server_url"], + model_name=model, prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') + OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user") ], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool] + ) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], - prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in 
prompt_messages], + server_url=credentials["server_url"], + prompt_messages=[ + self._convert_prompt_message_to_openllm_message(message) + for message in prompt_messages + ], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, + prompt_messages=prompt_messages, + credentials=credentials, + response=response, + ) + return self._handle_chat_generate_response( + model=model, + prompt_messages=prompt_messages, + credentials=credentials, + response=response, + ) - def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: + def _convert_prompt_message_to_openllm_message( + self, prompt_message: PromptMessage + ) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.USER.value, + content=prompt_message.content, + ) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, + content=prompt_message.content, + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError( + f"Prompt message type {type(prompt_message)} is not supported" + ) - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: OpenLLMGenerateMessage, + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,15 +203,20 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - 
completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, @@ -146,11 +224,12 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) else: @@ -160,72 +239,60 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +308,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 
79d0c478..05dc9488 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ from model_providers.core.model_runtime.model_providers.openllm.llm.openllm_gene class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,51 +78,63 @@ class OpenLLMGenerate: "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if ( + "max_tokens" in model_parameters + and type(model_parameters["max_tokens"]) == int + ): + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if ( + "temperature" in model_parameters + and type(model_parameters["temperature"]) == float + ): + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if ( + "use_cache" in model_parameters + and type(model_parameters["use_cache"]) == bool + ): + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = 
f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop if stop else [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: - response = post(url=url, data=dumps(data), timeout=timeout, stream=stream, headers=headers) + response = post( + url=url, + data=dumps(data), + timeout=timeout, + stream=stream, + headers=headers, + ) except (ConnectionError, InvalidSchema, MissingSchema) as e: # cloud not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +143,81 @@ class OpenLLMGenerate: raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - - def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: + + def _handle_chat_generate_response( + self, response: Response + ) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] - message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) + message = OpenLLMGenerateMessage( + content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value + ) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + raise InternalServerError( + f"Failed to convert response to json: {e} with text: {line}" + ) + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) - message = 
OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) + message = OpenLLMGenerateMessage( + content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value + ) - if 'finish_reason' in choice and choice['finish_reason']: - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if "finish_reason" in choice and choice["finish_reason"]: + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - + yield message diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6..309b5cf4 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/openllm.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/openllm.py index 21c5fc22..2bbf290b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/openllm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/openllm.py @@ -1,6 +1,8 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 920bfccc..5e1ed3a7 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -6,7 +6,10 @@ from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -15,17 +18,26 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from 
model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,16 +47,13 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +63,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,22 +71,22 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=total_tokens ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) + return result def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -104,9 +113,9 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,24 +128,16 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - 
], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -150,7 +151,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -161,7 +162,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py index 80a057fa..582cb8aa 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py @@ -1,15 +1,12 @@ from replicate.exceptions import ModelError, ReplicateError -from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError +from model_providers.core.model_runtime.errors.invoke import ( + InvokeBadRequestError, + InvokeError, +) class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py index 96be85e4..987cb4d0 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py @@ -6,7 +6,12 @@ from replicate.exceptions import ReplicateError from replicate.prediction import Prediction from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -22,100 +27,152 @@ from model_providers.core.model_runtime.entities.model_entities import ( ModelType, ParameterRule, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.replicate._common import _CommonReplicate +from 
model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.replicate._common import ( + _CommonReplicate, +) class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + version = credentials["model_version"] - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient( + api_token=credentials["replicate_api_token"], timeout=30 + ) model_info = client.models.get(model) model_info_version = model_info.versions.get(version) inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if ( + "system_prompt" + in model_info_version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ] + ): + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: - return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) - return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, prediction, stop, prompt_messages + ) + return self._handle_generate_response( + model, credentials, prediction, stop, prompt_messages + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError( + "Replicate Access Token must be provided." 
+ ) - if 'model_version' not in credentials: - raise CredentialsValidateFailedError('Replicate Model Version must be provided.') + if "model_version" not in credentials: + raise CredentialsValidateFailedError( + "Replicate Model Version must be provided." + ) if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, " + "format: {user_name}/{model_name}" + ) - version = credentials['model_version'] + version = credentials["model_version"] try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient( + api_token=credentials["replicate_api_token"], timeout=30 + ) model_info = client.models.get(model) model_info_version = model_info.versions.get(version) self._check_text_generation_model(model_info_version, model, version) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version): - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.") + if ( + "temperature" + not in model_info_version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ] + or "top_p" + not in model_info_version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ] + or "top_k" + not in model_info_version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ] + ): + raise CredentialsValidateFailedError( + f"Model {model_name}:{version} is not a Text Generation model." 
+ ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules( + model, credentials + ), ) return entity @classmethod - def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - version = credentials['model_version'] + def _get_customizable_model_parameter_rules( + cls, model: str, credentials: dict + ) -> list[ParameterRule]: + version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient( + api_token=credentials["replicate_api_token"], timeout=30 + ) model_info = client.models.get(model) model_info_version = model_info.versions.get(version) @@ -129,8 +186,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in ["system_prompt", "prompt"] and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -139,28 +196,28 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -171,7 +228,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -190,7 +247,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): index += 1 assistant_prompt_message = AssistantPromptMessage( - content=output if output else '' + content=output if output else "" ) if index < prediction_output_length: @@ -198,28 +255,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, 
LargeLanguageModel): model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + index=index, message=assistant_prompt_message + ), ) else: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + index=index, message=assistant_prompt_message, usage=usage + ), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -237,14 +301,16 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) result = LLMResult( model=model, @@ -257,21 +323,20 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_parameter_type(cls, param_type: str) -> str: - if param_type == 'integer': - return 'int' - elif param_type == 'number': - return 'float' - elif param_type == 'boolean': - return 'boolean' - elif param_type == 'string': - return 'string' + if param_type == "integer": + return "int" + elif param_type == "number": + return "float" + elif param_type == "boolean": + return "boolean" + elif param_type == "string": + return "string" def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) return text.rstrip() diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/replicate.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/replicate.py index 77c0aca9..72521b9a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/replicate.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/replicate.py @@ -1,11 +1,12 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from 
model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 4a957a32..a6884360 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -5,33 +5,52 @@ from typing import Optional from replicate import Client as ReplicateClient from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.replicate._common import _CommonReplicate +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelType, + PriceType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.replicate._common import ( + _CommonReplicate, +) class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: + client = ReplicateClient( + api_token=credentials["replicate_api_token"], timeout=30 + ) replicate_model_version = f'{model}:{credentials["model_version"]}' - text_input_key = self._get_text_input_key(model, credentials['model_version'], client) + text_input_key = self._get_text_input_key( + model, credentials["model_version"], client + ) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, texts + ) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -40,40 +59,48 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): return num_tokens def 
validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError( + "Replicate Access Token must be provided." + ) - if 'model_version' not in credentials: - raise CredentialsValidateFailedError('Replicate Model Version must be provided.') + if "model_version" not in credentials: + raise CredentialsValidateFailedError( + "Replicate Model Version must be provided." + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient( + api_token=credentials["replicate_api_token"], timeout=30 + ) replicate_model_version = f'{model}:{credentials["model_version"]}' - text_input_key = self._get_text_input_key(model, credentials['model_version'], client) + text_input_key = self._get_text_input_key( + model, credentials["model_version"], client + ) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @staticmethod - def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) -> str: + def _get_text_input_key( + model: str, model_version: str, client: ReplicateClient + ) -> str: model_info = client.models.get(model) model_info_version = model_info.versions.get(model_version) @@ -86,42 +113,50 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in ("text", "texts", "inputs"): text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, + replicate_model_version: str, + text_input_key: str, + texts: list[str], + ) -> list[list[float]]: + if text_input_key in ("text", "inputs"): embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run( + replicate_model_version, input={text_input_key: text} + ) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + 
replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: input_price_info = self.get_price( model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -132,7 +167,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/_client.py b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/_client.py index a4659454..56a36262 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/_client.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/_client.py @@ -14,36 +14,31 @@ import websocket class SparkLLMClient: - def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + def __init__( + self, + model: str, + app_id: str, + api_key: str, + api_secret: str, + api_domain: Optional[str] = None, + ): + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain - if model == 'spark-v3': - endpoint = 'multimodal' + if model == "spark-v3": + endpoint = "multimodal" model_api_configs = { - 'spark-1.5': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-2': { - 'version': 'v2.1', - 'chat_domain': 'generalv2' - }, - 'spark-3': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-3.5': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - } + "spark-1.5": {"version": "v1.1", "chat_domain": "general"}, + "spark-2": {"version": "v2.1", "chat_domain": "generalv2"}, + "spark-3": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-3.5": {"version": "v3.5", "chat_domain": "generalv3.5"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] - self.chat_domain = model_api_configs[model]['chat_domain'] + self.chat_domain = model_api_configs[model]["chat_domain"] self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( @@ -51,13 +46,15 @@ class SparkLLMClient: urlparse(self.api_base).path, self.api_base, api_key, - api_secret + api_secret, ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" - def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: + def create_url( + self, host: str, path: str, api_base: str, api_key: str, api_secret: str + ) -> str: # generate timestamp by RFC1123 now = datetime.now() date = format_date_time(mktime(now.timetuple())) @@ -67,33 +64,39 @@ class SparkLLMClient: signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), 
signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), + signature_origin.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( + encoding="utf-8" + ) - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" + urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run( + self, + messages: list, + user_id: str, + model_kwargs: Optional[dict] = None, + streaming: bool = False, + ): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -102,85 +105,82 @@ class SparkLLMClient: ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put( + {"status_code": error.status_code, "error": error.resp_body.decode("utf-8")} + ) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps( + self.gen_params( + messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs + ) + ) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put( + { + "status_code": 400, + "error": f"Code: {code}, Error: {data['header']['message']}", + } + ) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params( + self, messages: list, user_id: str, model_kwargs: Optional[dict] = None + ) -> dict: data = { - "header": { - "app_id": self.app_id, - "uid": user_id - }, - "parameter": { - "chat": { - "domain": self.chat_domain - } - }, - "payload": { - "message": { - "text": messages - } - } + "header": {"app_id": self.app_id, "uid": user_id}, + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if 
model_kwargs: - data['parameter']['chat'].update(model_kwargs) + data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. " + "Please try again after obtaining the necessary permissions." + ) else: - raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") + raise SparkError( + f"[Spark] code: {content['status_code']}, error: {content['error']}" + ) - if 'data' not in content: + if "data" not in content: break yield content diff --git a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py index 383fb7c7..c7ea29f3 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py @@ -2,7 +2,11 @@ import threading from collections.abc import Generator from typing import Optional, Union -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -18,19 +22,28 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) from ._client import SparkLLMClient class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -45,10 +58,17 @@ class SparkLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, 
prompt_messages, model_parameters, stop, stream, user) + return self._generate( + model, credentials, prompt_messages, model_parameters, stop, stream, user + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +100,21 @@ class SparkLargeLanguageModel(LargeLanguageModel): model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +129,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +139,40 @@ class SparkLargeLanguageModel(LargeLanguageModel): **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + { + "role": prompt_message.role.value, + "content": prompt_message.content, + } + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: - return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) + return self._handle_generate_stream_response( + thread, model, credentials, client, prompt_messages + ) - return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + return self._handle_generate_response( + thread, model, credentials, client, prompt_messages + ) + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +185,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,16 +193,18 @@ class SparkLargeLanguageModel(LargeLanguageModel): thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens 
= self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) # transform response result = LLMResult( @@ -168,9 +215,15 @@ class SparkLargeLanguageModel(LargeLanguageModel): ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -183,27 +236,29 @@ class SparkLargeLanguageModel(LargeLanguageModel): """ for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta if delta else "", ) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens( + model, credentials, [assistant_prompt_message] + ) # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + index=index, message=assistant_prompt_message, usage=usage + ), ) thread.join() @@ -216,9 +271,9 @@ class SparkLargeLanguageModel(LargeLanguageModel): :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +299,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -255,8 +310,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) # trim off the trailing ' ' that might come from the "Assistant: " @@ -277,5 +331,5 @@ class SparkLargeLanguageModel(LargeLanguageModel): InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/spark/spark.py b/model-providers/model_providers/core/model_runtime/model_providers/spark/spark.py index 247d83bb..68ecd9fb 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/spark/spark.py +++ 
b/model-providers/model_providers/core/model_runtime/model_providers/spark/spark.py @@ -1,6 +1,8 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py index 782ec228..c954affd 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -2,47 +2,88 @@ from collections.abc import Generator from typing import Optional, Union from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) from model_providers.core.model_runtime.entities.model_entities import AIModelEntity -from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel +from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke( + model, + cred_with_endpoint, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - 
return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, + cred_with_endpoint, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/togetherai.py b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/togetherai.py index 6226e9b8..22d0d09c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/togetherai.py @@ -1,11 +1,12 @@ import logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py index aa2bf5c9..da62624a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py @@ -5,7 +5,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py index cfe33558..d7bf35f3 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py @@ -10,10 +10,7 @@ class EnhanceTongyi(Tongyi): @property def _default_params(self) -> dict[str, Any]: """Get the default parameters for calling OpenAI API.""" - normal_params = { - "top_p": self.top_p, - "api_key": self.dashscope_api_key - } + normal_params = {"top_p": self.top_p, "api_key": self.dashscope_api_key} return {**normal_params, **self.model_kwargs} @@ -34,14 +31,14 @@ class EnhanceTongyi(Tongyi): if len(prompts) > 1: raise ValueError("Cannot stream results with multiple prompts.") params["stream"] = True - text = '' + text = "" for stream_resp in 
stream_generate_with_retry( self, prompt=prompts[0], **params ): if not generations: current_text = stream_resp["output"]["text"] else: - current_text = stream_resp["output"]["text"][len(text):] + current_text = stream_resp["output"]["text"][len(text) :] text = stream_resp["output"]["text"] diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py index f571ad76..f2b0741c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -14,7 +14,12 @@ from dashscope.common.error import ( from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -30,19 +35,28 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) from ._client import EnhanceTongyi class TongyiLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -57,13 +71,22 @@ class TongyiLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ - -> LLMResult | Generator: + return self._generate( + model, credentials, prompt_messages, model_parameters, stop, stream, user + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + callbacks: list[Callback] = 
None, + ) -> LLMResult | Generator: """ Wrapper for code block mode """ @@ -86,38 +109,46 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) block_prompts = block_prompts.replace("{{block}}", code_block) # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[0], SystemPromptMessage + ): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) + content=block_prompts.replace( + "{{instructions}}", prompt_messages[0].content + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace( + "{{instructions}}", + f"Please output a valid {code_block} object.", + ) + ), + ) mode = self.get_model_mode(model, credentials) if mode == LLMMode.CHAT: - if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[-1], UserPromptMessage + ): # add ```JSON\n to the last message prompt_messages[-1].content += f"\n```{code_block}\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) else: prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n")) @@ -129,20 +160,23 @@ if you are not sure about the structure. tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=response + model=model, prompt_messages=prompt_messages, input_generator=response ) - + return response - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -177,15 +211,21 @@ if you are not sure about the structure. model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -200,7 +240,7 @@ if you are not sure about the structure. 
""" extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -208,11 +248,11 @@ if you are not sure about the structure. client = EnhanceTongyi( model_name=model, streaming=stream, - dashscope_api_key=credentials_kwargs['api_key'], + dashscope_api_key=credentials_kwargs["api_key"], ) params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -221,28 +261,36 @@ if you are not sure about the structure. mode = self.get_model_mode(model, credentials) if mode == LLMMode.CHAT: - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages( + prompt_messages + ) else: - params['prompt'] = self._convert_messages_to_prompt(prompt_messages) + params["prompt"] = self._convert_messages_to_prompt(prompt_messages) if stream: responses = stream_generate_with_retry( - client, - stream=True, - incremental_output=True, - **params + client, stream=True, incremental_output=True, **params ) - return self._handle_generate_stream_response(model, credentials, responses, prompt_messages) + return self._handle_generate_stream_response( + model, credentials, responses, prompt_messages + ) response = generate_with_retry( client, **params, ) - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response(self, model: str, credentials: dict, response: DashScopeAPIResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + return self._handle_generate_response( + model, credentials, response, prompt_messages + ) + + def _handle_generate_response( + self, + model: str, + credentials: dict, + response: DashScopeAPIResponse, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -253,12 +301,15 @@ if you are not sure about the structure. :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.output.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.output.text) # transform usage - usage = self._calc_response_usage(model, credentials, response.usage.input_tokens, response.usage.output_tokens) + usage = self._calc_response_usage( + model, + credentials, + response.usage.input_tokens, + response.usage.output_tokens, + ) # transform response result = LLMResult( @@ -270,8 +321,13 @@ if you are not sure about the structure. return result - def _handle_generate_stream_response(self, model: str, credentials: dict, responses: Generator, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -286,17 +342,21 @@ if you are not sure about the structure. 
resp_content = response.output.text usage = response.usage - if resp_finish_reason is None and (resp_content is None or resp_content == ''): + if resp_finish_reason is None and ( + resp_content is None or resp_content == "" + ): continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content if resp_content else '', + content=resp_content if resp_content else "", ) if resp_finish_reason is not None: # transform usage - usage = self._calc_response_usage(model, credentials, usage.input_tokens, usage.output_tokens) + usage = self._calc_response_usage( + model, credentials, usage.input_tokens, usage.output_tokens + ) yield LLMResultChunk( model=model, @@ -305,17 +365,16 @@ if you are not sure about the structure. index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + index=index, message=assistant_prompt_message + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -326,7 +385,7 @@ if you are not sure about the structure. :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -352,7 +411,7 @@ if you are not sure about the structure. raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -363,14 +422,15 @@ if you are not sure about the structure. messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage]) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage] + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -380,20 +440,26 @@ if you are not sure about the structure. tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content, - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content, + } + ) elif isinstance(prompt_message, UserPromptMessage): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content, - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): - tongyi_messages.append({ - 'role': 'assistant', - 'content': prompt_message.content, - }) + tongyi_messages.append( + { + "role": "assistant", + "content": prompt_message.content, + } + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -424,5 +490,5 @@ if you are not sure about the structure. 
InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tongyi.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tongyi.py index 25c7a2cf..195c1809 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tongyi.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -21,11 +25,12 @@ class TongyiProvider(ModelProvider): # Use `qwen-turbo` model for validate, model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials + model="qwen-turbo", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py index 7818ec3f..aa3e7f88 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -4,12 +4,17 @@ from io import BytesIO from typing import Optional import dashscope -from pydub import AudioSegment from fastapi.responses import StreamingResponse +from pydub import AudioSegment + from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel -from model_providers.core.model_runtime.model_providers.tongyi._common import _CommonTongyi +from model_providers.core.model_runtime.model_providers.tongyi._common import ( + _CommonTongyi, +) from model_providers.extensions.ext_storage import storage @@ -18,8 +23,16 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. 
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, - user: Optional[str] = None) -> any: + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + streaming: bool, + user: Optional[str] = None, + ) -> any: """ _invoke text2speech model @@ -33,18 +46,33 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] + for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) if streaming: - return StreamingResponse(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice), media_type='text/event-stream') + return StreamingResponse( + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text=content_text, + tenant_id=tenant_id, + voice=voice, + ), + media_type="text/event-stream", + ) else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + return self._tts_invoke( + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) - def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + def validate_credentials( + self, model: str, credentials: dict, user: Optional[str] = None + ) -> None: """ validate credentials text2speech model @@ -57,13 +85,15 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): self._tts_invoke( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: + def _tts_invoke( + self, model: str, credentials: dict, content_text: str, voice: str + ) -> Response: """ _tts_invoke text2speech model @@ -77,14 +107,25 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) audio_bytes_list = list() # Create a thread pool and map the function to the list of sentences - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(self._process_sentence, sentence=sentence, - credentials=credentials, voice=voice, audio_type=audio_type) for sentence in - sentences] + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + futures = [ + executor.submit( + self._process_sentence, + sentence=sentence, + credentials=credentials, + voice=voice, + audio_type=audio_type, + ) + for sentence in sentences + ] for future in futures: try: if future.result(): @@ -93,8 +134,11 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): raise InvokeBadRequestError(str(ex)) if len(audio_bytes_list) > 0: - audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), 
format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] + audio_segments = [ + AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) + for audio_bytes in audio_bytes_list + if audio_bytes + ] combined_segment = reduce(lambda x, y: x + y, audio_segments) buffer: BytesIO = BytesIO() combined_segment.export(buffer, format=audio_type) @@ -104,8 +148,14 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): raise InvokeBadRequestError(str(ex)) # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + ) -> any: """ _tts_invoke_streaming text2speech model @@ -116,25 +166,33 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param content_text: text content to be translated :return: text translated to audio file """ - dashscope.api_key = credentials.get('dashscope_api_key') + dashscope.api_key = credentials.get("dashscope_api_key") word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' + file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}" try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list( + self._split_text_into_sentences(text=content_text, limit=word_limit) + ) for sentence in sentences: - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - text=sentence.strip(), - format=audio_type, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + text=sentence.strip(), + format=audio_type, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) if isinstance(response.get_audio_data(), bytes): storage.save(file_path, response.get_audio_data()) except Exception as ex: raise InvokeBadRequestError(str(ex)) @staticmethod - def _process_sentence(sentence: str, credentials: dict, voice: str, audio_type: str): + def _process_sentence( + sentence: str, credentials: dict, voice: str, audio_type: str + ): """ _tts_invoke Tongyi text2speech model api @@ -144,9 +202,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param audio_type: audio file type :return: text translated to audio file """ - dashscope.api_key = credentials.get('dashscope_api_key') - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - text=sentence.strip(), - format=audio_type) + dashscope.api_key = credentials.get("dashscope_api_key") + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, sample_rate=48000, text=sentence.strip(), format=audio_type + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index e3fcd9db..39464d4e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -7,7 +7,9 @@ from typing import Any, Union from 
requests import Response, post -from model_providers.core.model_runtime.entities.message_entities import PromptMessageTool +from model_providers.core.model_runtime.entities.message_entities import ( + PromptMessageTool, +) from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( BadRequestError, InternalServerError, @@ -17,9 +19,10 @@ from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_err ) # map api_key to access_token -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} baidu_access_tokens_lock = Lock() + class BaiduAccessToken: api_key: str access_token: str @@ -27,48 +30,56 @@ class BaiduAccessToken: def __init__(self, api_key: str) -> None: self.api_key = api_key - self.access_token = '' + self.access_token = "" self.expires = datetime.now() + timedelta(days=3) def _get_access_token(api_key: str, secret_key: str) -> str: """ - request access token from Baidu + request access token from Baidu """ try: response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' + "Content-Type": "application/json", + "Accept": "application/json", }, ) except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + raise InvalidAuthenticationError( + f"Failed to get access token from Baidu: {e}" + ) resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': - raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': - raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': + if "error" in resp: + if resp["error"] == "invalid_client": + raise InvalidAPIKeyError( + f'Invalid API key or secret key: {resp["error_description"]}' + ) + elif resp["error"] == "unknown_error": + raise InternalServerError( + f'Internal server error: {resp["error_description"]}' + ) + elif resp["error"] == "invalid_request": raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': - raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') + elif resp["error"] == "rate_limit_exceeded": + raise RateLimitReachedError( + f'Rate limit reached: {resp["error_description"]}' + ) else: raise Exception(f'Unknown error: {resp["error_description"]}') - - return resp['access_token'] + + return resp["access_token"] @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. 
(avoid memory leak) - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. """ # loop up cache, remove expired access token @@ -97,53 +108,61 @@ class BaiduAccessToken: baidu_access_tokens_lock.release() return token + class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role + class ErnieBotModel: api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", } function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', + "ernie-bot", + "ernie-bot-8k", ] - api_key: str = '' - secret_key: str = '' + api_key: str = "" + secret_key: str = "" def __init__(self, api_key: str, secret_key: str): self.api_key = api_key self.secret_key = secret_key - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: - + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -151,27 +170,34 @@ class ErnieBotModel: access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, + messages=messages_cloned, + stream=stream, + parameters=parameters, + tools=tools, + stop=stop, + user=user, + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": 
"application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - + def _handle_error(self, code: int, msg: str): error_map = { 1: InternalServerError, @@ -205,26 +231,31 @@ class ErnieBotModel: 336105: BadRequestError, 336200: InternalServerError, 336303: BadRequestError, - 337006: BadRequestError + 337006: BadRequestError, } if code in error_map: raise error_map[code](msg) else: - raise InternalServerError(f'Unknown error: {msg}') + raise InternalServerError(f"Unknown error: {msg}") def _get_access_token(self) -> str: token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) return token.access_token - + def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, + model: str, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') - + raise BadRequestError(f"Invalid model: {model}") + # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') # ErnieBot supports function calling, however, there is lots of limitations. @@ -232,101 +263,125 @@ class ErnieBotModel: # so, we just disable function calling for now. 
if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') - - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + raise BadRequestError("stop item should not exceed 20 characters.") + + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) - return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + return self._build_chat_request_body( + model, messages, stream, parameters, stop, user + ) + + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') - + raise BadRequestError("The number of messages should be odd.") + if messages[0].role == "function": + raise BadRequestError("The first message should be user message.") + """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') - + raise BadRequestError("The number of messages should not be zero.") + # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } if system_message: - body['system'] = 
system_message + body["system"] = system_message return body - + def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[ErnieMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[ErnieMessage, None, None]: for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') - - if line.startswith('data:'): + raise InternalServerError(f"Failed to parse response: {e}") + + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -336,23 +391,23 @@ class ErnieBotModel: try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') - - result = data['result'] - is_end = data['is_end'] + raise InternalServerError(f"Failed to parse response: {e}") + + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py index 67d76b4a..4e56e58d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + 
class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py index 952dbbbf..1f7a638b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -2,7 +2,11 @@ from collections.abc import Generator from typing import Optional, Union, cast from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -18,9 +22,17 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot import ( + BaiduAccessToken, + ErnieBotModel, + ErnieMessage, +) from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( BadRequestError, InsufficientAccountBalance, @@ -41,78 +53,152 @@ if you are not sure about the structure. You should also complete the text started with ``` but not tell ``` directly. 
""" -class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: +class ErnieBotLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters[ + "response_format" + ] in ["JSON", "XML"]: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + response_format, + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - - return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + return self._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) 
+ + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[0], SystemPromptMessage + ): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", prompt_messages[0].content + ).replace("{{block}}", response_format) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", + f"Please output a valid {response_format} object.", + ).replace("{{block}}", response_format) + ), + ) - if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + if len(prompt_messages) > 0 and isinstance( + prompt_messages[-1], UserPromptMessage + ): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -124,10 +210,10 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -137,41 +223,62 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: BaiduAccessToken._get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: 
list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user if user else "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: - return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) + return self._handle_chat_generate_stream_response( + model, prompt_messages, credentials, response + ) else: - return self._handle_chat_generate_response(model, prompt_messages, credentials, response) + return self._handle_chat_generate_response( + model, prompt_messages, credentials, response + ) def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ @@ -191,43 +298,57 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ErnieMessage, + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + 
response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) else: @@ -237,10 +358,11 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=message.content, - tool_calls=[] + content=message.content, tool_calls=[] ), - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason + if message.stop_reason + else None, ), ) @@ -255,21 +377,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalance, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/wenxin.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/wenxin.py index 01e194b5..724a7e3d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/wenxin.py @@ -1,11 +1,16 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -20,11 +25,12 @@ class WenxinProvider(ModelProvider): # Use `ernie-bot` model for validate, model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials + model="ernie-bot", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py 
b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py index 98d2d175..aa139170 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py @@ -13,8 +13,15 @@ from openai import ( RateLimitError, UnprocessableEntityError, ) -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, +) from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion from xinference_client.client.restful.restful_client import ( @@ -25,7 +32,12 @@ from xinference_client.client.restful.restful_client import ( ) from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -51,8 +63,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) from model_providers.core.model_runtime.model_providers.xinference.xinference_helper import ( XinferenceHelper, XinferenceModelExtraParameter, @@ -61,82 +77,114 @@ from model_providers.core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + 
model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) + server_url=credentials["server_url"], model_uid=credentials["model_uid"] + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - + if ( + "/" in credentials["model_uid"] + or "?" in credentials["model_uid"] + or "#" in credentials["model_uid"] + ): + raise CredentialsValidateFailedError( + "model_uid should not contain /, ?, or #" + ) + extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + server_url=credentials["server_url"], model_uid=credentials["model_uid"] ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: - raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') - + raise ValueError( + f"xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type" + ) + if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError( + f"Xinference credentials validate failed: {e}" + ) except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError( + f"Xinference credentials validate failed: {e}" + ) except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we 
just take the GPT2 tokenizer as default + cause XinferenceAI LLM is a customized model, we could net detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + tools: list[PromptMessageTool], + is_completion_model: bool = False, + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -152,9 +200,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': + if isinstance(item, dict) and item["type"] == "text": text += item.text value = text @@ -191,7 +239,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): num_tokens += self._num_tokens_for_tools(tools) return num_tokens - + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for tool calling @@ -200,46 +248,47 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) return num_tokens - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): text += item.content @@ -248,7 +297,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): elif isinstance(item, AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError( + f"PromptMessage type {type(item)} is not supported" + ) return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -267,182 +318,215 @@ class 
XinferenceAILargeLanguageModel(LargeLanguageModel): if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + message_dict = { + "tool_call_id": message.tool_call_id, + "role": "tool", + "content": message.content, + } else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') + raise ValueError( + f'completion_type {credentials["completion_type"]} is not supported' + ) else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + server_url=credentials["server_url"], model_uid=credentials["model_uid"] ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') - - support_function_call = credentials.get('support_function_call', False) - context_length = credentials.get('context_length', 2048) + raise ValueError( + f"xinference model ability {extra_args.model_ability} is not supported" + ) + + support_function_call = credentials.get("support_function_call", False) + context_length = credentials.get("context_length", 2048) 
entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - features=[ - ModelFeature.TOOL_CALL - ] if support_function_call else [], - model_properties={ + features=[ModelFeature.TOOL_CALL] if support_function_call else [], + model_properties={ ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length + ModelPropertyKey.CONTEXT_SIZE: context_length, }, - parameter_rules=rules + parameter_rules=rules, ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - """ - generate text from LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + generate text from LLM + + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + """ + if "server_url" not in credentials: + raise CredentialsValidateFailedError( + "server_url is required in credentials" + ) + + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] client = OpenAI( base_url=f'{credentials["server_url"]}/v1', - api_key='abc', + api_key="abc", max_retries=3, timeout=60, ) xinference_client = Client( - base_url=credentials['server_url'], + base_url=credentials["server_url"], ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools + generate_config["tools"] = [ + {"type": "function", "function": helper.dump_model(tool)} + for tool in tools ] - if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): + if isinstance( + xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle + ): resp = client.chat.completions.create( - model=credentials['model_uid'], - messages=[self._convert_prompt_message_to_dict(message) for message in 
prompt_messages], + model=credentials["model_uid"], + messages=[ + self._convert_prompt_message_to_dict(message) + for message in prompt_messages + ], stream=stream, user=user, **generate_config, ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError( + "xinference tool calls does not support stream mode" + ) + return self._handle_chat_stream_response( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + resp=resp, + ) + return self._handle_chat_generate_response( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + resp=resp, + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + resp=resp, + ) + return self._handle_completion_generate_response( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + resp=resp, + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError( + f"xinference model handle type {type(xinference_model)} is not supported" + ) - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, + response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall], + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -454,20 +538,21 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + arguments=response_tool_call.function.arguments, ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call.id, type=response_tool_call.type, - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -478,45 +563,61 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( 
name=response_function_call.name, - arguments=response_function_call.arguments + arguments=response_function_call.arguments, ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = resp.choices[0].message # convert tool call to assistant message tool call tool_calls = assistant_message.tool_calls - assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + assistant_prompt_message_tool_calls = self._extract_response_tool_calls( + tool_calls if tool_calls else [] + ) function_call = assistant_message.function_call if function_call: - assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] + assistant_prompt_message_tool_calls += [ + self._extract_response_function_call(function_call) + ] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + tool_calls=assistant_prompt_message_tool_calls, ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[assistant_prompt_message], tools=tools + ) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) response = LLMResult( model=model, @@ -528,13 +629,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -542,9 +648,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and ( + delta.delta.content is None or delta.delta.content == "" + ): continue - + # check if there is a tool call in the response function_call = None tool_calls = [] @@ -555,27 
+663,36 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): assistant_message_tool_calls = self._extract_response_tool_calls(tool_calls) if function_call: - assistant_message_tool_calls += [self._extract_response_function_call(function_call)] + assistant_message_tool_calls += [ + self._extract_response_function_call(function_call) + ] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=assistant_message_tool_calls, ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) - prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) - completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) + prompt_tokens = self._num_tokens_from_messages( + messages=prompt_messages, tools=tools + ) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -584,7 +701,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -600,21 +717,25 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") - + assistant_message = resp.choices[0].text # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] + content=assistant_message, tool_calls=[] ) prompt_tokens = self._get_num_tokens_by_gpt2( @@ -624,7 +745,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): messages=[assistant_prompt_message], tools=[], is_completion_model=True ) usage = self._calc_response_usage( - model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) response = LLMResult( @@ -637,13 +761,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: 
Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -653,26 +782,30 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] + content=delta.text if delta.text else "", tool_calls=[] ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] + content=full_response, tool_calls=[] ) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_text(prompt_messages) ) completion_tokens = self._num_tokens_from_messages( - messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True + messages=[temp_assistant_prompt_message], + tools=[], + is_completion_model=True, ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -681,11 +814,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -720,15 +853,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError + PermissionDeniedError, ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py index 6d9fdb0f..4291d3c1 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,10 +1,20 @@ from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulRerankModelHandle, +) from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelType, +) +from 
model_providers.core.model_runtime.entities.rerank_entities import ( + RerankDocument, + RerankResult, +) from model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -13,8 +23,12 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.rerank_model import ( + RerankModel, +) class XinferenceRerankModel(RerankModel): @@ -22,10 +36,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,23 +59,20 @@ class XinferenceRerankModel(RerankModel): :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + if credentials["server_url"].endswith("/"): + credentials["server_url"] = credentials["server_url"][:-1] # initialize client - client = Client( - base_url=credentials['server_url'] - ) + client = Client(base_url=credentials["server_url"]) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a rerank model" + ) response = xinference_client.rerank( documents=docs, @@ -64,27 +81,24 @@ class XinferenceRerankModel(RerankModel): ) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] + index = result["index"] + page_content = result["document"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check if score_threshold is not None: - if result['relevance_score'] >= score_threshold: + if result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -95,19 +109,25 @@ class XinferenceRerankModel(RerankModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - + if ( + "/" in credentials["model_uid"] + or "?" in credentials["model_uid"] + or "#" in credentials["model_uid"] + ): + raise CredentialsValidateFailedError( + "model_uid should not contain /, ?, or #" + ) + self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -123,38 +143,26 @@ class XinferenceRerankModel(RerankModel): :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index d56047a5..2cabf59b 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,11 +1,23 @@ import time from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulEmbeddingModelHandle, +) from model_providers.core.model_runtime.entities.common_entities import I18nObject -from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceType, +) +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) from 
model_providers.core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -14,18 +26,29 @@ from model_providers.core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.xinference.xinference_helper import ( + XinferenceHelper, +) class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -41,27 +64,29 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] - if server_url.endswith('/'): + if server_url.endswith("/"): server_url = server_url[:-1] client = Client(base_url=server_url) - + try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) try: embeddings = handle.create_embedding(input=texts) except RuntimeError as e: raise InvokeServerUnavailableError(e) - + """ for convenience, the response json is like: class Embedding(TypedDict): @@ -78,13 +103,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=usage["total_tokens"] + ) result = TextEmbeddingResult( model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + embeddings=[embedding["embedding"] for embedding in embeddings["data"]], + usage=usage, ) return result @@ -113,43 +140,45 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: - raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + if ( + "/" in credentials["model_uid"] + or "?" 
in credentials["model_uid"] + or "#" in credentials["model_uid"] + ): + raise CredentialsValidateFailedError( + "model_uid should not contain /, ?, or #" + ) + + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, model_uid=model_uid + ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens + credentials["max_tokens"] = extra_args.max_tokens - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError( + f"Failed to validate credentials for model {model}: {e}" + ) except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -163,7 +192,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -174,28 +203,30 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity | None: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials + and credentials["max_tokens"] + or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference.py index 20fc4a5a..9c140dab 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference.py @@ -1,6 +1,8 @@ import 
logging -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py index 66dab658..6194a0cb 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -15,8 +15,15 @@ class XinferenceModelExtraParameter: context_length: int = 2048 support_function_call: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, max_tokens: int, context_length: int) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + max_tokens: int, + context_length: int, + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -24,80 +31,103 @@ class XinferenceModelExtraParameter: self.max_tokens = max_tokens self.context_length = context_length + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter( + server_url: str, model_uid: str + ) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter( + server_url, model_uid + ), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [ + model_uid + for model_uid, model in cache.items() + if model["expires"] < time() + ] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter( + server_url: str, model_uid: str + ) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ - if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + if ( + not model_uid + or not model_uid.strip() + or not server_url + or not server_url.strip() + ): + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + 
session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError( + f"get xinference model extra parameter failed, url: {url}, error: {e}" + ) if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') - + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}" + ) + response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = response_json.get("model_ability", []) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') - - support_function_call = 'tools' in model_ability - max_tokens = response_json.get('max_tokens', 512) + raise NotImplementedError( + f"xinference model handle type {model_handle_type} is not supported" + ) + + support_function_call = "tools" in model_ability + max_tokens = response_json.get("max_tokens", 512) + + context_length = response_json.get("context_length", 2048) - context_length = response_json.get('context_length', 2048) - return XinferenceModelExtraParameter( model_format=model_format, model_handle_type=model_handle_type, model_ability=model_ability, support_function_call=support_function_call, max_tokens=max_tokens, - context_length=context_length - ) \ No newline at end of file + context_length=context_length, + ) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py index 840af233..6f6595ed 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,11 @@ class _CommonZhipuaiAI: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None, + "api_key": credentials["api_key"] + if "api_key" in credentials + else credentials["zhipuai_api_key"] + if "zhipuai_api_key" in credentials + else None, } return credentials_kwargs @@ -38,5 +41,5 @@ class _CommonZhipuaiAI: InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py 
b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py index 8e64f283..b3721cd0 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,7 +1,11 @@ from collections.abc import Generator from typing import Optional, Union -from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from model_providers.core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from model_providers.core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -11,12 +15,24 @@ from model_providers.core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI -from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion -from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.large_language_model import ( + LargeLanguageModel, +) +from model_providers.core.model_runtime.model_providers.zhipuai._common import ( + _CommonZhipuaiAI, +) +from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ( + ZhipuAI, +) +from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import ( + Completion, +) +from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, +) from model_providers.core.model_runtime.utils import helper GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. 
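For context, the credential fallback introduced in _CommonZhipuaiAI._to_credential_kwargs above reduces to a small precedence rule. The following is a minimal, self-contained sketch of that behaviour; the helper name and the assertions are illustrative and not part of the patch:

def to_credential_kwargs(credentials: dict) -> dict:
    # "api_key" takes precedence, "zhipuai_api_key" is the fallback, otherwise None.
    return {
        "api_key": credentials["api_key"]
        if "api_key" in credentials
        else credentials["zhipuai_api_key"]
        if "zhipuai_api_key" in credentials
        else None,
    }


assert to_credential_kwargs({"api_key": "k1", "zhipuai_api_key": "k2"})["api_key"] == "k1"
assert to_credential_kwargs({"zhipuai_api_key": "k2"})["api_key"] == "k2"
assert to_credential_kwargs({})["api_key"] is None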
@@ -31,13 +47,19 @@ And you should always end the block with a "```" to indicate the end of the JSON ```JSON""" -class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: +class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -57,11 +79,20 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # invoke model # stop = stop or [] # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate( + model, + credentials_kwargs, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -91,8 +122,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -127,16 +163,22 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -151,15 +193,13 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if 
len(prompt_messages) == 0: - raise ValueError('At least one message is required') - + raise ValueError("At least one message is required") + if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: prompt_messages = prompt_messages[1:] @@ -168,13 +208,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in [ + PromptMessageRole.USER, + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + ]: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model != 'glm-4v': + if model != "glm-4v": # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -184,134 +228,157 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: - new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): + new_prompt_messages[-1].content += ( + "\n\n" + copy_prompt_message.content + ) else: if copy_prompt_message.role == PromptMessageRole.USER: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.TOOL: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: - new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) + new_prompt_message = SystemPromptMessage( + content=copy_prompt_message.content + ) new_prompt_messages.append(new_prompt_message) else: - new_prompt_message = UserPromptMessage(content=copy_prompt_message.content) + new_prompt_message = UserPromptMessage( + content=copy_prompt_message.content + ) new_prompt_messages.append(new_prompt_message) else: - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT: - new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT + ): + new_prompt_messages[-1].content += ( + "\n\n" + copy_prompt_message.content + ) else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v': + if model == "glm-4v": params = { - 'model': model, - 'messages': [{ - 'role': prompt_message.role.value, - 'content': - [ - { - 'type': 'text', - 'text': prompt_message.content - } - ] if isinstance(prompt_message.content, str) else - [ - { - 'type': 'image', - 'image_url': { - 'url': content.data - } - } if content.type == PromptMessageContentType.IMAGE else { - 'type': 'text', - 'text': content.data - } for content in prompt_message.content + "model": model, + "messages": [ + { + "role": prompt_message.role.value, + "content": [{"type": "text", "text": prompt_message.content}] + if isinstance(prompt_message.content, str) + else [ + {"type": "image", "image_url": {"url": content.data}} + if content.type == PromptMessageContentType.IMAGE + else {"type": "text", "text": content.data} + for content in 
prompt_message.content ], - } for prompt_message in new_prompt_messages], - **model_parameters + } + for prompt_message in new_prompt_messages + ], + **model_parameters, } else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append( + {"role": "assistant", "content": prompt_message.content} + ) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + { + "role": prompt_message.role.value, + "content": prompt_message.content, + } + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if ( + prompt_message.role == PromptMessageRole.SYSTEM + or prompt_message.role == PromptMessageRole.TOOL + or prompt_message.role == PromptMessageRole.USER + ): + if ( + len(params["messages"]) > 0 + and params["messages"][-1]["role"] == "user" + ): + params["messages"][-1]["content"] += ( + "\n\n" + prompt_message.content + ) else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append( + {"role": "user", "content": prompt_message.content} + ) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + { + "role": prompt_message.role.value, + "content": prompt_message.content, + } + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools + params["tools"] = [ + {"type": "function", "function": helper.dump_model(tool)} + for tool in tools ] if stream: - response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) - return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) + response = 
client.chat.completions.create( + stream=stream, **params, **extra_model_kwargs + ) + return self._handle_generate_stream_response( + model, credentials_kwargs, tools, response, prompt_messages + ) response = client.chat.completions.create(**params, **extra_model_kwargs) - return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + return self._handle_generate_response( + model, credentials_kwargs, tools, response, prompt_messages + ) + + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -320,12 +387,12 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -333,36 +400,40 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' - + text += choice.message.content or "" + prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens # transform usage - usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage) + usage = self._calc_response_usage( + model, credentials, prompt_usage, completion_usage + ) # transform response result = LLMResult( model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls + content=text, tool_calls=assistant_tool_calls ), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -371,19 +442,21 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and ( + delta.delta.content is None or delta.delta.content == "" + ): continue - + assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( 
AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -391,45 +464,47 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content if delta.delta.content else "", + tool_calls=assistant_tool_calls, ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content if delta.delta.content else "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens prompt_tokens = chunk.usage.prompt_tokens # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage( + model, credentials, prompt_tokens, completion_tokens + ) yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -454,8 +529,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return message_text - - def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. 
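The message normalisation in _generate above joins consecutive user messages with a blank line before they are sent to the API. A minimal sketch of that merge rule, using plain role/content dicts instead of the project's PromptMessage classes (the function name and test data are illustrative only):

def merge_adjacent_user_messages(messages: list[dict]) -> list[dict]:
    # Consecutive {"role": "user"} entries are joined with "\n\n"; other roles pass through.
    merged: list[dict] = []
    for message in messages:
        if merged and merged[-1]["role"] == "user" and message["role"] == "user":
            merged[-1]["content"] += "\n\n" + message["content"]
        else:
            merged.append(dict(message))
    return merged


assert merge_adjacent_user_messages(
    [
        {"role": "user", "content": "hello"},
        {"role": "user", "content": "how are you"},
        {"role": "assistant", "content": "hi"},
    ]
) == [
    {"role": "user", "content": "hello\n\nhow are you"},
    {"role": "assistant", "content": "hi"},
]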
@@ -463,8 +541,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): messages = messages.copy() # don't mutate the original list text = "".join( - self._convert_one_message_to_text(message) - for message in messages + self._convert_one_message_to_text(message) for message in messages ) if tools and len(tools) > 0: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 6d3df778..ca75fe79 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -2,11 +2,22 @@ import time from typing import Optional from model_providers.core.model_runtime.entities.model_entities import PriceType -from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI +from model_providers.core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from model_providers.core.model_runtime.model_providers.zhipuai._common import ( + _CommonZhipuaiAI, +) +from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ( + ZhipuAI, +) class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): @@ -14,9 +25,13 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -27,16 +42,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, - usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + usage=self._calc_response_usage( + model, credentials_kwargs, embedding_used_tokens + ), + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -68,20 +83,20 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]: + def embed_documents( + self, model: str, client: ZhipuAI, texts: list[str] + ) -> tuple[list[list[float]], int]: """Call out to ZhipuAI's embedding endpoint. 
Args: @@ -112,7 +127,9 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): """ return self.embed_documents([text])[0] - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage( + self, model: str, credentials: dict, tokens: int + ) -> EmbeddingUsage: """ Calculate response usage @@ -125,7 +142,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): model=model, credentials=credentials, price_type=PriceType.INPUT, - tokens=tokens + tokens=tokens, ) # transform usage @@ -136,7 +153,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.py index ca53d988..4f356214 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -1,8 +1,12 @@ import logging from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from model_providers.core.model_runtime.errors.validate import ( + CredentialsValidateFailedError, +) +from model_providers.core.model_runtime.model_providers.__base.model_provider import ( + ModelProvider, +) logger = logging.getLogger(__name__) @@ -20,11 +24,12 @@ class ZhipuaiProvider(ModelProvider): model_instance = self.get_model_instance(ModelType.LLM) model_instance.validate_credentials( - model='chatglm_turbo', - credentials=credentials + model="chatglm_turbo", credentials=credentials ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) raise ex diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 8a687ef4..bf9b093c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,6 +1,14 @@ - from .__version__ import __version__ from ._client import ZhipuAI -from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError, - APIResponseError, APIResponseValidationError, APIServerFlowExceedError, APIStatusError, - APITimeoutError, ZhipuAIError) +from .core._errors import ( + APIAuthenticationError, + APIInternalError, + APIReachLimitError, + APIRequestFailedError, + APIResponseError, + APIResponseValidationError, + APIServerFlowExceedError, + APIStatusError, + APITimeoutError, + ZhipuAIError, +) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py 
b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332..659f38d7 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1 @@ - -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = "v2.0.1" diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 29b17463..27173a4d 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -20,14 +20,14 @@ class ZhipuAI(HttpClient): api_key: str def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, + max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: # if api_key is None: # api_key = os.environ.get("ZHIPUAI_API_KEY") @@ -40,6 +40,7 @@ class ZhipuAI(HttpClient): if base_url is None: base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ + super().__init__( version=__version__, base_url=base_url, @@ -60,9 +61,11 @@ class ZhipuAI(HttpClient): return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): + if ( + not hasattr(self, "_has_custom_http_client") + or not hasattr(self, "close") + or not hasattr(self, "_client") + ): # if the '__init__' method raised an error, self would not have client attr return diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5..ce5d737e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,25 +17,24 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - 
tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus @@ -71,16 +70,13 @@ class AsyncCompletions(BaseAPI): disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) + extra_headers=extra_headers, timeout=timeout + ), ) - - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index 5c4ed4d1..ec29f338 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -20,24 +20,24 @@ class Completions(BaseAPI): super().__init__(client) def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = 
NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: _cast_type = Completion _stream_cls = StreamResponse[ChatCompletionChunk] diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 35d54592..4da0276a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -18,16 +18,16 @@ class Embeddings(BaseAPI): super().__init__(client) def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + input: Union[str, list[str], list[int], list[list[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: _cast_type = EmbeddingsResponded if disable_strict_validation: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08..f48dc4ff 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -17,17 +17,16 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: if not is_file_content(file): prefix = f"Expected file input `{file!r}`" @@ -51,14 +50,14 @@ class Files(BaseAPI): ) def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | 
NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca..dc30bd33 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index b860de19..ecdf455e 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -7,7 +7,12 @@ import httpx from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven from ...core._http_client import make_user_request_input -from ...types.fine_tuning import FineTuningJob, FineTuningJobEvent, ListOfFineTuningJob, job_create_params +from ...types.fine_tuning import ( + FineTuningJob, + FineTuningJobEvent, + ListOfFineTuningJob, + job_create_params, +) if TYPE_CHECKING: from ..._client import ZhipuAI @@ -16,21 +21,20 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( "/fine_tuning/jobs", @@ -49,11 +53,11 @@ class Jobs(BaseAPI): ) def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", @@ -64,12 +68,12 @@ class Jobs(BaseAPI): ) def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - 
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", @@ -93,7 +97,6 @@ class Jobs(BaseAPI): extra_headers: Headers | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: - return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 3201426d..63325ce5 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -18,19 +18,19 @@ class Images(BaseAPI): super().__init__(client) def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: _cast_type = ImagesResponded if disable_strict_validation: diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index b7cf6bb7..40630556 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -88,7 +88,9 @@ FileTypes = Union[ FileContent, # file content tuple[str, FileContent], # (filename, file) tuple[str, FileContent, str], # (filename, file , content_type) - tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[ + str, FileContent, str, Mapping[str, str] + ], # (filename, file , content_type, headers) ] RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] @@ -100,7 +102,11 @@ HttpxFileTypes = Union[ FileContent, # file content tuple[str, HttpxFileContent], # (filename, file) tuple[str, HttpxFileContent, str], # (filename, file , content_type) - tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[ + str, HttpxFileContent, str, Mapping[str, str] + ], # (filename, 
file , content_type, headers) ] -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] +HttpxRequestFiles = Union[ + Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]] +] diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index a2a438b8..1800a3a3 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -17,7 +17,10 @@ __all__ = [ class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: + def __init__( + self, + message: str, + ) -> None: super().__init__(message) @@ -68,15 +71,16 @@ class APIResponseValidationError(APIResponseError): response: httpx.Response def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None + self, + response: httpx.Response, + json_data: object | None, + *, + message: str | None = None, ) -> None: super().__init__( message=message or "Data returned by API invalid for expected schema.", request=response.request, - json_data=json_data + json_data=json_data, ) self.response = response self.status_code = response.status_code diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py index 0796bfe1..e7fa1ad2 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py @@ -25,7 +25,9 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes: else: return (file[0], file[1], *file[2:]) else: - raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type") + raise TypeError( + f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type" + ) def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: @@ -37,5 +39,7 @@ def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: elif isinstance(files, Sequence): files = [(key, _transform_file(file)) for key, file in files] else: - raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence") + raise TypeError( + f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence" + ) return files diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index e13d2b02..9e968d52 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -9,7 +9,16 @@ import pydantic from httpx import URL, Timeout from . 
import _errors -from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT +from ._base_type import ( + NOT_GIVEN, + Body, + Data, + Headers, + NotGiven, + Query, + RequestFiles, + ResponseT, +) from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError from ._files import make_httpx_files from ._request_opt import ClientRequestParam, UserRequestInput @@ -46,16 +55,19 @@ class HttpClient: _default_stream_cls: type[StreamResponse[Any]] | None = None def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, + self, + *, + version: str, + base_url: URL, + timeout: Union[float, Timeout, None], + custom_httpx_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, ) -> None: if timeout is None or isinstance(timeout, NotGiven): - if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: + if ( + custom_httpx_client + and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT + ): timeout = custom_httpx_client.timeout else: timeout = ZHIPUAI_DEFAULT_TIMEOUT @@ -74,7 +86,6 @@ class HttpClient: self._custom_headers = custom_headers or {} def _prepare_url(self, url: str) -> URL: - sub_url = URL(url) if sub_url.is_relative_url: request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") @@ -84,16 +95,15 @@ class HttpClient: @property def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } + return { + "Accept": "application/json", + "Content-Type": "application/json; charset=UTF-8", + "ZhipuAI-SDK-Ver": self._version, + "source_type": "zhipu-sdk-python", + "x-request-sdk": "zhipu-sdk-python", + **self._auth_headers, + **self._custom_headers, + } @property def _auth_headers(self): @@ -107,10 +117,7 @@ class HttpClient: return httpx_headers - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: + def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: kwargs: dict[str, Any] = {} json_data = request_param.json_data headers = self._prepare_headers(request_param) @@ -124,7 +131,9 @@ class HttpClient: return self._client.build_request( headers=headers, - timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout, + timeout=self.timeout + if isinstance(request_param.timeout, NotGiven) + else request_param.timeout, method=request_param.method, url=url, json=json_data, @@ -133,7 +142,9 @@ class HttpClient: **kwargs, ) - def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: + def _object_to_formfata( + self, key: str, value: Data | Mapping[object, object] + ) -> list[tuple[str, str]]: items = [] if isinstance(value, Mapping): @@ -162,7 +173,6 @@ class HttpClient: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - items = flatten([self._object_to_formfata(k, v) for k, v in data.items()]) serialized: dict[str, object] = {} @@ -173,30 +183,29 @@ class HttpClient: return serialized def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: 
ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + enable_stream: bool, + request_param: ClientRequestParam, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> HttpResponse: - http_response = HttpResponse( raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, - stream_cls=stream_cls + stream_cls=stream_cls, ) return http_response.parse() def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, + self, + *, + data: object, + cast_type: type[ResponseT], + response: httpx.Response, ) -> ResponseT: if data is None: return cast(ResponseT, None) @@ -205,7 +214,9 @@ class HttpClient: if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel): return cast(ResponseT, cast_type.validate(data)) - return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)) + return cast( + ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data) + ) except pydantic.ValidationError as err: raise APIResponseValidationError(response=response, json_data=data) from err @@ -222,12 +233,12 @@ class HttpClient: self.close() def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + cast_type: type[ResponseT], + params: ClientRequestParam, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: request = self._prepare_request(params) @@ -256,81 +267,98 @@ class HttpClient: ) def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + enable_stream: bool = False, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="get", url=path, **options) return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream + cast_type=cast_type, params=opts, enable_stream=enable_stream ) def post( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) + opts = ClientRequestParam.construct( + method="post", + json_data=body, + files=make_httpx_files(files), + url=path, + **options, + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, enable_stream=enable_stream, - stream_cls=stream_cls + stream_cls=stream_cls, ) def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT: - opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) + opts = ClientRequestParam.construct( + 
method="patch", url=path, json_data=body, **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) + opts = ClientRequestParam.construct( + method="put", + url=path, + json_data=body, + files=make_httpx_files(files), + **options, + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, + self, + path: str, + *, + body: Body | None = None, + cast_type: type[ResponseT], + options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) + opts = ClientRequestParam.construct( + method="delete", url=path, json_data=body, **options + ) return self.request( - cast_type=cast_type, params=opts, + cast_type=cast_type, + params=opts, ) def _make_status_error(self, response) -> APIStatusError: @@ -347,15 +375,17 @@ class HttpClient: elif status_code == 500: return _errors.APIInternalError(message=error_msg, response=response) elif status_code == 503: - return _errors.APIServerFlowExceedError(message=error_msg, response=response) + return _errors.APIServerFlowExceedError( + message=error_msg, response=response + ) return APIStatusError(message=error_msg, response=response) def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - query: Query | None = None, + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers = None, + query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -364,7 +394,7 @@ def make_user_request_input( if max_retries is not None: options["max_retries"] = max_retries if not isinstance(timeout, NotGiven): - options['timeout'] = timeout + options["timeout"] = timeout if query is not None: options["params"] = query diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba8..7bd5b3e4 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -35,10 +35,10 @@ class ClientRequestParam: @classmethod def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: kwargs: dict[str, Any] = { key: remove_notgiven_indict(value) for key, value in values.items() } @@ -48,4 +48,3 @@ class ClientRequestParam: return client model_construct = construct - diff --git 
a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 2f831b6f..7addfd8c 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -26,13 +26,13 @@ class HttpResponse(Generic[R]): http_response: httpx.Response def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: HttpClient, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: self._cast_type = cast_type self._client = client @@ -52,8 +52,8 @@ class HttpResponse(Generic[R]): self._stream_cls( cast_type=cast(type, get_args(self._stream_cls)[0]), response=self.http_response, - client=self._client - ) + client=self._client, + ), ) return self._parsed cast_type = self._cast_type @@ -63,7 +63,9 @@ class HttpResponse(Generic[R]): if cast_type == str: return cast(R, http_response.text) - content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") + content_type, *_ = http_response.headers.get( + "content-type", "application/json" + ).split(";") origin = get_origin(cast_type) or cast_type if content_type != "application/json": if issubclass(origin, pydantic.BaseModel): diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 66afbfd1..ce3b6df6 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -16,16 +16,15 @@ if TYPE_CHECKING: class StreamResponse(Generic[ResponseT]): - response: httpx.Response _cast_type: type[ResponseT] def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, ) -> None: self.response = response self._cast_type = cast_type @@ -39,7 +38,6 @@ class StreamResponse(Generic[ResponseT]): yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: - sse_line_parser = SSELineParser() iterator = sse_line_parser.iter_lines(self.response.iter_lines()) @@ -56,18 +54,20 @@ class StreamResponse(Generic[ResponseT]): json_data=data["error"], ) - yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) + yield self._data_process_func( + data=data, cast_type=self._cast_type, response=self.response + ) for sse in iterator: pass class Event: def __init__( - self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None + self, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None, ): self._event = event self._data = data @@ -79,18 +79,23 @@ class Event: return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" @property - def 
event(self): return self._event + def event(self): + return self._event @property - def data(self): return self._data + def data(self): + return self._data - def json_data(self): return json.loads(self._data) + def json_data(self): + return json.loads(self._data) @property - def id(self): return self._id + def id(self): + return self._id @property - def retry(self): return self._retry + def retry(self): + return self._retry class SSELineParser: @@ -107,18 +112,20 @@ class SSELineParser: def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: for line in lines: - line = line.rstrip('\n') + line = line.rstrip("\n") if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: + if ( + self._event is None + and not self._data + and self._id is None + and self._retry is None + ): continue sse_event = Event( event=self._event, - data='\n'.join(self._data), + data="\n".join(self._data), id=self._id, - retry=self._retry + retry=self._retry, ) self._event = None self._data = [] @@ -134,7 +141,7 @@ class SSELineParser: field, _p, value = line.partition(":") - if value.startswith(' '): + if value.startswith(" "): value = value[1:] if field == "data": self._data.append(value) diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d2..a0645b09 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c5..4b3a929a 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 917bda75..75f76fe9 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -6,7 +6,6 @@ __all__ = ["FileObject"] class FileObject(BaseModel): - id: Optional[str] = None bytes: Optional[int] = None created_at: Optional[int] = None @@ -18,7 +17,6 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): - object: Optional[str] = None data: list[FileObject] has_more: Optional[bool] = None diff --git 
a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaf..1d393028 100644 --- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from typing import Optional, Union from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py b/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py index fc8e4480..8d56fb65 100644 --- a/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py +++ b/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py @@ -1,16 +1,21 @@ from typing import Optional -from model_providers.core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType +from model_providers.core.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormType, +) class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + need_validate_credential_form_schema_map[ + credential_form_schema.variable + ] = credential_form_schema continue all_show_on_match = True @@ -24,20 +29,25 @@ class CommonValidator: break if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + need_validate_credential_form_schema_map[ + credential_form_schema.variable + ] = credential_form_schema # Iterate over the remaining credential_form_schemas, verify each credential_form_schema validated_credentials = {} for credential_form_schema in need_validate_credential_form_schema_map.values(): # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) + result = self._validate_credential_form_schema( + credential_form_schema, credentials + ) if result: validated_credentials[credential_form_schema.variable] = result return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -46,10 +56,15 @@ class CommonValidator: :return: validated credential form schema value """ # If the variable does not exist in credentials - if credential_form_schema.variable not in credentials or not 
credentials[credential_form_schema.variable]: + if ( + credential_form_schema.variable not in credentials + or not credentials[credential_form_schema.variable] + ): # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError( + f"Variable {credential_form_schema.variable} is required" + ) else: # Get the value of default if credential_form_schema.default: @@ -65,23 +80,33 @@ class CommonValidator: # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError( + f"Variable {credential_form_schema.variable} should be string" + ) if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: # If the value is in options, no validation is performed if credential_form_schema.options: - if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + if value not in [ + option.value for option in credential_form_schema.options + ]: + raise ValueError( + f"Variable {credential_form_schema.variable} is not in options" + ) if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in ["true", "false"]: + raise ValueError( + f"Variable {credential_form_schema.variable} should be true or false" + ) - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/model-providers/model_providers/core/model_runtime/schema_validators/model_credential_schema_validator.py b/model-providers/model_providers/core/model_runtime/schema_validators/model_credential_schema_validator.py index f65c56a1..0a92af32 100644 --- a/model-providers/model_providers/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/model-providers/model_providers/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -1,11 +1,16 @@ from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.entities.provider_entities import ModelCredentialSchema -from model_providers.core.model_runtime.schema_validators.common_validator import CommonValidator +from model_providers.core.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, +) +from model_providers.core.model_runtime.schema_validators.common_validator import ( + CommonValidator, +) class ModelCredentialSchemaValidator(CommonValidator): - - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): + def __init__( + self, model_type: ModelType, model_credential_schema: ModelCredentialSchema + ): self.model_type = model_type self.model_credential_schema = model_credential_schema @@ -25,4 +30,6 @@ class 
ModelCredentialSchemaValidator(CommonValidator): credentials["__model_type"] = self.model_type.value - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) + return self._validate_and_filter_credential_form_schemas( + credential_form_schemas, credentials + ) diff --git a/model-providers/model_providers/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/model-providers/model_providers/core/model_runtime/schema_validators/provider_credential_schema_validator.py index 4c121a94..2ec9ce2a 100644 --- a/model-providers/model_providers/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/model-providers/model_providers/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -1,9 +1,12 @@ -from model_providers.core.model_runtime.entities.provider_entities import ProviderCredentialSchema -from model_providers.core.model_runtime.schema_validators.common_validator import CommonValidator +from model_providers.core.model_runtime.entities.provider_entities import ( + ProviderCredentialSchema, +) +from model_providers.core.model_runtime.schema_validators.common_validator import ( + CommonValidator, +) class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema @@ -15,6 +18,10 @@ class ProviderCredentialSchemaValidator(CommonValidator): :return: validated provider credentials """ # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.provider_credential_schema.credential_form_schemas + credential_form_schemas = ( + self.provider_credential_schema.credential_form_schemas + ) - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) + return self._validate_and_filter_credential_form_schemas( + credential_form_schemas, credentials + ) diff --git a/model-providers/model_providers/core/model_runtime/utils/encoders.py b/model-providers/model_providers/core/model_runtime/utils/encoders.py index cf6c98e0..7c98c5e0 100644 --- a/model-providers/model_providers/core/model_runtime/utils/encoders.py +++ b/model-providers/model_providers/core/model_runtime/utils/encoders.py @@ -4,7 +4,14 @@ from collections import defaultdict, deque from collections.abc import Callable from decimal import Decimal from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, +) from pathlib import Path, PurePath from re import Pattern from types import GeneratorType @@ -78,7 +85,7 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( tuple @@ -153,7 +160,7 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) diff --git a/model-providers/model_providers/core/model_runtime/utils/helper.py b/model-providers/model_providers/core/model_runtime/utils/helper.py 
index 09d08fa3..7774868f 100644 --- a/model-providers/model_providers/core/model_runtime/utils/helper.py +++ b/model-providers/model_providers/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ from pydantic import BaseModel def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.dict() diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index a2016963..1d86e5c3 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -5,9 +5,15 @@ from typing import Optional from sqlalchemy.exc import IntegrityError -from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity -from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \ - ProviderModelBundle +from model_providers.core.entities.model_entities import ( + DefaultModelEntity, + DefaultModelProviderEntity, +) +from model_providers.core.entities.provider_configuration import ( + ProviderConfiguration, + ProviderConfigurations, + ProviderModelBundle, +) from model_providers.core.entities.provider_entities import ( CustomConfiguration, CustomModelConfiguration, @@ -27,11 +33,17 @@ class ProviderManager: ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ - def __init__(self, - provider_name_to_provider_records_dict: dict, - provider_name_to_provider_model_records_dict: dict) -> None: - self.provider_name_to_provider_records_dict = provider_name_to_provider_records_dict - self.provider_name_to_provider_model_records_dict = provider_name_to_provider_model_records_dict + def __init__( + self, + provider_name_to_provider_records_dict: dict, + provider_name_to_provider_model_records_dict: dict, + ) -> None: + self.provider_name_to_provider_records_dict = ( + provider_name_to_provider_records_dict + ) + self.provider_name_to_provider_model_records_dict = ( + provider_name_to_provider_model_records_dict + ) def get_configurations(self, provider: str) -> ProviderConfigurations: """ @@ -80,24 +92,27 @@ class ProviderManager: for provider_entity in provider_entities: provider_name = provider_entity.provider - provider_credentials = self.provider_name_to_provider_records_dict.get(provider_entity.provider) + provider_credentials = self.provider_name_to_provider_records_dict.get( + provider_entity.provider + ) if not provider_credentials: provider_credentials = {} - provider_model_records = self.provider_name_to_provider_model_records_dict.get(provider_entity.provider) + provider_model_records = ( + self.provider_name_to_provider_model_records_dict.get( + provider_entity.provider + ) + ) if not provider_model_records: provider_model_records = [] # Convert to custom configuration custom_configuration = self._to_custom_configuration( - provider_entity, - provider_credentials, - provider_model_records + provider_entity, provider_credentials, provider_model_records ) provider_configuration = ProviderConfiguration( - provider=provider_entity, - custom_configuration=custom_configuration + provider=provider_entity, custom_configuration=custom_configuration ) provider_configurations[provider_name] = provider_configuration @@ -105,7 +120,9 @@ class ProviderManager: # Return the encapsulated object return provider_configurations - def 
get_provider_model_bundle(self, provider: str, model_type: ModelType) -> ProviderModelBundle: + def get_provider_model_bundle( + self, provider: str, model_type: ModelType + ) -> ProviderModelBundle: """ Get provider model bundle. :param provider: provider name @@ -125,7 +142,7 @@ class ProviderManager: return ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -142,8 +159,7 @@ class ProviderManager: # get available models from provider_configurations available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True + model_type=model_type, only_active=True ) if available_models: @@ -151,8 +167,8 @@ class ProviderManager: for available_model in available_models: if available_model.model == "gpt-3.5-turbo-1106": default_model = { - 'provider_name': available_model.provider.provider, - 'model_name': available_model.model + "provider_name": available_model.provider.provider, + "model_name": available_model.model, } found = True break @@ -160,11 +176,13 @@ class ProviderManager: if not found: available_model = available_models[0] default_model = { - 'provider_name': available_model.provider.provider, - 'model_name': available_model.model + "provider_name": available_model.provider.provider, + "model_name": available_model.model, } - provider_instance = model_provider_factory.get_provider_instance(default_model.get('provider_name')) + provider_instance = model_provider_factory.get_provider_instance( + default_model.get("provider_name") + ) provider_schema = provider_instance.get_provider_schema() return DefaultModelEntity( @@ -175,14 +193,16 @@ class ProviderManager: label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) - def _to_custom_configuration(self, - provider_entity: ProviderEntity, - provider_credentials: dict, - provider_model_records: list[dict]) -> CustomConfiguration: + def _to_custom_configuration( + self, + provider_entity: ProviderEntity, + provider_credentials: dict, + provider_model_records: list[dict], + ) -> CustomConfiguration: """ Convert to custom configuration. 
@@ -194,7 +214,8 @@ class ProviderManager: # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) for variable in provider_credential_secret_variables: @@ -210,38 +231,43 @@ class ProviderManager: # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials custom_model_configurations = [] for provider_model_record in provider_model_records: - if not provider_model_record.get('model_credentials'): + if not provider_model_record.get("model_credentials"): continue provider_model_credentials = {} for variable in model_credential_secret_variables: - if variable in provider_model_record.get('model_credentials'): + if variable in provider_model_record.get("model_credentials"): try: - provider_model_credentials[variable] = provider_model_record.get('model_credentials').get( - variable) + provider_model_credentials[ + variable + ] = provider_model_record.get("model_credentials").get(variable) except ValueError: pass custom_model_configurations.append( CustomModelConfiguration( - model=provider_model_record.get('model_name'), - model_type=ModelType.value_of(provider_model_record.get('model_type')), - credentials=provider_model_credentials + model=provider_model_record.get("model_name"), + model_type=ModelType.value_of( + provider_model_record.get("model_type") + ), + credentials=provider_model_credentials, ) ) return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations + provider=custom_provider_configuration, models=custom_model_configurations ) - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + def _extract_secret_variables( + self, credential_form_schemas: list[CredentialFormSchema] + ) -> list[str]: """ Extract secret input form variables. 
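The ProviderManager hunks above are formatting-only; the public surface (get_configurations, get_provider_model_bundle, get_default_model) is unchanged. For orientation, a minimal sketch of that surface follows. The placeholder API key and the assumption that the openai provider plugin resolves are illustrative, not taken from this diff.

# Sketch only: exercising the ProviderManager surface reformatted above.
# The API key is a placeholder; import paths follow this repository layout.
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.provider_manager import ProviderManager

manager = ProviderManager(
    provider_name_to_provider_records_dict={"openai": {"openai_api_key": "sk-***"}},
    provider_name_to_provider_model_records_dict={},
)

# Resolve configuration, provider instance and model-type instance in one bundle.
bundle = manager.get_provider_model_bundle(provider="openai", model_type=ModelType.LLM)
print(bundle.configuration.provider.provider)      # provider name, e.g. "openai"
print(type(bundle.model_type_instance).__name__)   # the LLM model-type implementation

# Pick a default LLM, preferring gpt-3.5-turbo-1106 when it is available and active.
print(manager.get_default_model(model_type=ModelType.LLM))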
diff --git a/model-providers/model_providers/core/utils/generic.py b/model-providers/model_providers/core/utils/generic.py index b93b0c57..06eb2a7a 100644 --- a/model-providers/model_providers/core/utils/generic.py +++ b/model-providers/model_providers/core/utils/generic.py @@ -1,7 +1,6 @@ import json from typing import TYPE_CHECKING, Any, Dict - if TYPE_CHECKING: from pydantic import BaseModel @@ -18,4 +17,3 @@ def jsonify(data: "BaseModel") -> str: return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) except Exception: # pydantic v1 return data.json(exclude_unset=True, ensure_ascii=False) - diff --git a/model-providers/model_providers/core/utils/json_dumps.py b/model-providers/model_providers/core/utils/json_dumps.py index 041615ce..2f39ea5a 100644 --- a/model-providers/model_providers/core/utils/json_dumps.py +++ b/model-providers/model_providers/core/utils/json_dumps.py @@ -1,5 +1,6 @@ -import orjson import os + +import orjson from pydantic import BaseModel diff --git a/model-providers/model_providers/core/utils/position_helper.py b/model-providers/model_providers/core/utils/position_helper.py index e038390e..55fd754c 100644 --- a/model-providers/model_providers/core/utils/position_helper.py +++ b/model-providers/model_providers/core/utils/position_helper.py @@ -8,8 +8,8 @@ import yaml def get_position_map( - folder_path: AnyStr, - file_name: str = '_position.yaml', + folder_path: AnyStr, + file_name: str = "_position.yaml", ) -> dict[str, int]: """ Get the mapping from name to index from a YAML file @@ -22,7 +22,7 @@ def get_position_map( if not os.path.exists(position_file_name): return {} - with open(position_file_name, encoding='utf-8') as f: + with open(position_file_name, encoding="utf-8") as f: positions = yaml.safe_load(f) position_map = {} for index, name in enumerate(positions): @@ -30,14 +30,16 @@ def get_position_map( position_map[name.strip()] = index return position_map except: - logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.') + logging.warning( + f"Failed to load the YAML position file {folder_path}/{file_name}." + ) return {} def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map. @@ -50,13 +52,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into a ordered dict by the position map. 
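position_helper.py is likewise only reformatted. As a quick reminder of how the helpers compose, a small sketch follows; the folder path and provider names are made up for illustration.

# Sketch only: composing the position helpers reformatted above.
# "path/to/providers" is a placeholder; its _position.yaml is expected to be a plain YAML list of names.
from model_providers.core.utils.position_helper import (
    get_position_map,
    sort_by_position_map,
)

# e.g. path/to/providers/_position.yaml containing:
#   - openai
#   - anthropic
#   - zhipuai
position_map = get_position_map("path/to/providers")

providers = ["zhipuai", "unlisted-provider", "openai", "anthropic"]
# Names missing from the map sort last, because they fall back to float("inf").
print(sort_by_position_map(position_map, providers, name_func=lambda name: name))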
diff --git a/model-providers/model_providers/errors/error.py b/model-providers/model_providers/errors/error.py index 6ac95b39..65085a6b 100644 --- a/model-providers/model_providers/errors/error.py +++ b/model-providers/model_providers/errors/error.py @@ -3,6 +3,7 @@ from typing import Optional class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ class LLMError(Exception): class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,4 +39,5 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" diff --git a/model-providers/model_providers/extensions/ext_redis.py b/model-providers/model_providers/extensions/ext_redis.py index c758ccb7..015706e3 100644 --- a/model-providers/model_providers/extensions/ext_redis.py +++ b/model-providers/model_providers/extensions/ext_redis.py @@ -6,18 +6,21 @@ redis_client = redis.Redis() def init_app(app): connection_class = Connection - if app.config.get('REDIS_USE_SSL', False): + if app.config.get("REDIS_USE_SSL", False): connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool(**{ - 'host': app.config.get('REDIS_HOST', 'localhost'), - 'port': app.config.get('REDIS_PORT', 6379), - 'username': app.config.get('REDIS_USERNAME', None), - 'password': app.config.get('REDIS_PASSWORD', None), - 'db': app.config.get('REDIS_DB', 0), - 'encoding': 'utf-8', - 'encoding_errors': 'strict', - 'decode_responses': False - }, connection_class=connection_class) + redis_client.connection_pool = redis.ConnectionPool( + **{ + "host": app.config.get("REDIS_HOST", "localhost"), + "port": app.config.get("REDIS_PORT", 6379), + "username": app.config.get("REDIS_USERNAME", None), + "password": app.config.get("REDIS_PASSWORD", None), + "db": app.config.get("REDIS_DB", 0), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + }, + connection_class=connection_class, + ) - app.extensions['redis'] = redis_client + app.extensions["redis"] = redis_client diff --git a/model-providers/model_providers/extensions/ext_storage.py b/model-providers/model_providers/extensions/ext_storage.py index be85290f..854fa824 100644 --- a/model-providers/model_providers/extensions/ext_storage.py +++ b/model-providers/model_providers/extensions/ext_storage.py @@ -16,29 +16,29 @@ class Storage: self.folder = None def init_config(self, config: dict): - self.storage_type = config.get('STORAGE_TYPE') - if self.storage_type == 's3': - self.bucket_name = config.get('S3_BUCKET_NAME') + self.storage_type = config.get("STORAGE_TYPE") + if self.storage_type == "s3": + self.bucket_name = config.get("S3_BUCKET_NAME") self.client = boto3.client( - 's3', - aws_secret_access_key=config.get('S3_SECRET_KEY'), - aws_access_key_id=config.get('S3_ACCESS_KEY'), - endpoint_url=config.get('S3_ENDPOINT'), - region_name=config.get('S3_REGION') + "s3", + aws_secret_access_key=config.get("S3_SECRET_KEY"), + 
aws_access_key_id=config.get("S3_ACCESS_KEY"), + endpoint_url=config.get("S3_ENDPOINT"), + region_name=config.get("S3_REGION"), ) else: - self.folder = config.get('STORAGE_LOCAL_PATH') + self.folder = config.get("STORAGE_LOCAL_PATH") if not os.path.isabs(self.folder): - self.folder = os.path.join(config.get('root_path'), self.folder) + self.folder = os.path.join(config.get("root_path"), self.folder) def save(self, filename, data): - if self.storage_type == 's3': + if self.storage_type == "s3": self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) else: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) @@ -53,20 +53,22 @@ class Storage: return self.load_once(filename) def load_once(self, filename: str) -> bytes: - if self.storage_type == 's3': + if self.storage_type == "s3": try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)[ + "Body" + ].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise else: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -78,22 +80,24 @@ class Storage: def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - if self.storage_type == 's3': + if self.storage_type == "s3": try: with closing(self.client) as client: - response = client.get_object(Bucket=self.bucket_name, Key=filename) - for chunk in response['Body'].iter_chunks(): + response = client.get_object( + Bucket=self.bucket_name, Key=filename + ) + for chunk in response["Body"].iter_chunks(): yield chunk except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise else: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -105,14 +109,14 @@ class Storage: return generate() def download(self, filename, target_filepath): - if self.storage_type == 's3': + if self.storage_type == "s3": with closing(self.client) as client: client.download_file(self.bucket_name, filename, target_filepath) else: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -120,7 +124,7 @@ class Storage: shutil.copyfile(filename, target_filepath) def exists(self, filename): - if self.storage_type == 's3': + if self.storage_type == "s3": with closing(self.client) as client: try: 
client.head_object(Bucket=self.bucket_name, Key=filename) @@ -128,14 +132,12 @@ class Storage: except: return False else: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename return os.path.exists(filename) storage = Storage() - - diff --git a/model-providers/scripts/check_imports.py b/model-providers/scripts/check_imports.py new file mode 100644 index 00000000..825bea5b --- /dev/null +++ b/model-providers/scripts/check_imports.py @@ -0,0 +1,22 @@ +import random +import string +import sys +import traceback +from importlib.machinery import SourceFileLoader + +if __name__ == "__main__": + files = sys.argv[1:] + has_failure = False + for file in files: + try: + module_name = "".join( + random.choice(string.ascii_letters) for _ in range(20) + ) + SourceFileLoader(module_name, file).load_module() + except Exception: + has_failure = True + print(file) # noqa: T201 + traceback.print_exc() + print() # noqa: T201 + + sys.exit(1 if has_failure else 0) diff --git a/model-providers/scripts/check_pydantic.sh b/model-providers/scripts/check_pydantic.sh new file mode 100644 index 00000000..06b5bb81 --- /dev/null +++ b/model-providers/scripts/check_pydantic.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# +# This script searches for lines starting with "import pydantic" or "from pydantic" +# in tracked files within a Git repository. +# +# Usage: ./scripts/check_pydantic.sh /path/to/repository + +# Check if a path argument is provided +if [ $# -ne 1 ]; then + echo "Usage: $0 /path/to/repository" + exit 1 +fi + +repository_path="$1" + +# Search for lines matching the pattern within the specified repository +result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') + +# Check if any matching lines were found +if [ -n "$result" ]; then + echo "ERROR: The following lines need to be updated:" + echo "$result" + echo "Please replace the code with an import from langchain_core.pydantic_v1." + echo "For example, replace 'from pydantic import BaseModel'" + echo "with 'from langchain_core.pydantic_v1 import BaseModel'" + exit 1 +fi diff --git a/model-providers/scripts/lint_imports.sh b/model-providers/scripts/lint_imports.sh new file mode 100644 index 00000000..19cf642a --- /dev/null +++ b/model-providers/scripts/lint_imports.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -eu + +# Initialize a variable to keep track of errors +errors=0 + +# make sure not importing from chatchat +git --no-pager grep '^from chatchat\.' . && errors=$((errors+1)) + +# Decide on an exit status based on the errors +if [ "$errors" -gt 0 ]; then + exit 1 +else + exit 0 +fi
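Looking back at the ext_storage.py hunks above, the Storage singleton keeps its two backends (s3 and local); only the quoting changed. A minimal sketch of the local branch follows; all config values are invented, and only the key names come from init_config.

# Sketch only: the local-filesystem branch of Storage. All paths and values below are placeholders.
from model_providers.extensions.ext_storage import storage

storage.init_config(
    {
        "STORAGE_TYPE": "local",          # anything other than "s3" selects the local branch
        "STORAGE_LOCAL_PATH": "storage",  # relative, so it is joined onto root_path
        "root_path": "/tmp/model-providers",
    }
)

storage.save("providers/openai.json", b'{"openai_api_key": "sk-***"}')
print(storage.exists("providers/openai.json"))     # True once the file is written
print(storage.load_once("providers/openai.json"))  # the bytes written above

The three new scripts are self-contained checks: check_imports.py takes the Python files to import-check as positional arguments and exits non-zero if any of them fails to import, check_pydantic.sh greps a repository for top-level pydantic imports, and lint_imports.sh fails if anything imports from chatchat.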