* Add usage instructions

* 3.8 compatibility configuration

* fix

* formatter

* Cross-platform compatibility test cases

* Embedding compatibility

* Add logging messages
glide-the 2024-04-16 16:20:08 +08:00 committed by GitHub
parent 4ce7ce0709
commit 2a33f9d4dd
155 changed files with 1062 additions and 3503 deletions

View File

@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict,
provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
model_platforms_shard['provider_platforms'] = provider_platforms
boot.serve(logging_conf=logging_conf)
boot.logging_conf(logging_conf=logging_conf)
boot.run()
async def pool_join_thread():
await boot.join()

View File

@ -42,3 +42,39 @@ make format
make format_diff
```
This command is especially useful when you have changed part of the project and want to make sure the changed portion is formatted correctly without touching the rest of the codebase.
### Getting Started
Once the project is installed, configure the `model_providers.yaml` file to complete platform loading (a usage sketch follows the setup steps below).
> Note: before configuring a platform, make sure its dependencies are fully installed. For example, the Zhipu platform requires the Zhipu SDK: `pip install zhipuai`
`model_providers` contains the global configuration `provider_credential` and the model configuration `model_credential` offered by the different platforms.
The configuration to load differs from platform to platform. For details on how to fill in this file,
see the platform `yaml` files under the package `model_providers.core.model_runtime.model_providers`.
For example, `zhipuai.yaml` defines a `provider_credential_schema` that contains a single variable, `api_key`.
To load the Zhipu platform, proceed as follows:
- Install the SDK
```shell
$ pip install zhipuai
```
- Edit `model_providers.yaml`
```yaml
zhipuai:
provider_credential:
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.2'
```
- Run the `model-providers` pytest tests
```shell
poetry run pytest tests/server_unit_test/test_init_server.py
```
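Once the service is up with this configuration, each provider is exposed through an OpenAI-compatible HTTP interface (see the `list_models` / `create_embeddings` / `create_chat_completion` handlers elsewhere in this commit). Below is a minimal client sketch; the port, the `/zhipuai/v1` path prefix, and the `embedding-2` model name are illustrative assumptions, not values taken from this commit.
```python
# Minimal sketch: query the OpenAI-compatible endpoint for the zhipuai provider.
# Host/port, the "/zhipuai/v1" prefix, and the model name are assumptions.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:20000/zhipuai/v1",  # hypothetical address of the local service
    api_key="EMPTY",  # credentials live in model_providers.yaml, not in the client
)

# list the models the provider reports
for model in client.models.list().data:
    print(model.id)

# request an embedding through the unified interface
resp = client.embeddings.create(model="embedding-2", input="hello world")
print(len(resp.data[0].embedding))
```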

View File

@ -27,3 +27,7 @@ xinference:
model_uid: 'chatglm3-6b'
zhipuai:
provider_credential:
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'

View File

@ -16,7 +16,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-providers",
type=str,
default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
default="/mnt/d/project/Langchain-Chatchat/model-providers/model_providers.yaml",
help="run model_providers servers",
dest="model_providers",
)
@ -36,7 +36,8 @@ if __name__ == "__main__":
.build()
)
boot.set_app_event(started_event=None)
boot.serve(logging_conf=logging_conf)
boot.logging_conf(logging_conf=logging_conf)
boot.run()
async def pool_join_thread():
await boot.join()

View File

@ -50,7 +50,7 @@ class SystemConfigurationResponse(BaseModel):
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
quota_configurations: List[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
@ -65,8 +65,8 @@ class ProviderResponse(BaseModel):
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
supported_model_types: List[ModelType]
configurate_methods: List[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
@ -114,7 +114,7 @@ class ProviderWithModelsResponse(BaseModel):
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
models: list[ModelResponse]
models: List[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)

View File

@ -18,6 +18,8 @@ from typing import (
cast,
)
import tiktoken
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
@ -67,8 +69,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._port = port
self._router = APIRouter()
self._app = FastAPI()
self._logging_conf = None
self._server = None
self._server_thread = None
def logging_conf(self,logging_conf: Optional[dict] = None):
self._logging_conf = logging_conf
@classmethod
def from_config(cls, cfg=None):
host = cfg.get("host", "127.0.0.1")
@ -79,7 +86,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
return cls(host=host, port=port)
def serve(self, logging_conf: Optional[dict] = None):
def run(self):
self._app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@ -125,18 +132,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._app.include_router(self._router)
config = Config(
app=self._app, host=self._host, port=self._port, log_config=logging_conf
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
)
server = Server(config)
self._server = Server(config)
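# run uvicorn in a background thread so run() returns immediately;
# callers wait for shutdown via join() (or trigger it via destroy())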
def run_server():
server.run()
self._server.shutdown_timeout = 2  # set to 2 seconds
self._server.run()
self._server_thread = threading.Thread(target=run_server)
self._server_thread.start()
async def join(self):
await self._server_thread.join()
def destroy(self):
logger.info("Shutting down server")
self._server.should_exit = True  # set the exit flag
self._server.shutdown()  # stop the server
self.join()
def join(self):
self._server_thread.join()
def set_app_event(self, started_event: mp.Event = None):
@self._app.on_event("startup")
@ -159,7 +177,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
async def list_models(self, provider: str, request: Request):
logger.info(f"Received list_models request for provider: {provider}")
# enumerate all ModelType members
llm_models: list[AIModelEntity] = []
llm_models: List[AIModelEntity] = []
for model_type in ModelType.__members__.values():
try:
provider_model_bundle = (
@ -176,7 +194,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
logger.error(e)
# convert models list[AIModelEntity] to List[ModelCard]
# convert models List[AIModelEntity] to List[ModelCard]
models_list = [
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
@ -191,16 +209,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
model_instance = self._provider_manager.get_model_instance(
provider=provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embeddings_request.model,
)
texts = embeddings_request.input
if isinstance(texts, str):
texts = [texts]
response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
return await openai_embedding_text(response)
try:
model_instance = self._provider_manager.get_model_instance(
provider=provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embeddings_request.model,
)
# check whether embeddings_request.input is a list
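# Note (client-behavior assumption): some OpenAI-compatible clients send the input
# pre-tokenized as lists of token ids, so decode them back to text with tiktoken first.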
input = ''
if isinstance(embeddings_request.input, list):
tokens = embeddings_request.input
try:
encoding = tiktoken.encoding_for_model(embeddings_request.model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
for i, token in enumerate(tokens):
text = encoding.decode(token)
input += text
else:
input = embeddings_request.input
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
return await openai_embedding_text(response)
except ValueError as e:
logger.error(f"Error while creating embeddings: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
logger.error(f"Error while creating embeddings: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
@ -254,9 +297,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
else:
return await openai_chat_completion(response)
except ValueError as e:
logger.error(f"Error while creating chat completion: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
logger.error(f"Error while creating chat completion: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@ -273,10 +318,11 @@ def run(
cfg=cfg.get("run_openai_api", {})
)
api.set_app_event(started_event=started_event)
api.serve(logging_conf=logging_conf)
api.logging_conf(logging_conf=logging_conf)
api.run()
async def pool_join_thread():
await api.join()
api.join()
asyncio.run(pool_join_thread())
except SystemExit:

View File

@ -145,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = 256
stop: Optional[list[str]] = None
stop: Optional[List[str]] = None
stream: Optional[bool] = False
def to_model_parameters_dict(self, *args, **kwargs):

View File

@ -112,7 +112,7 @@ class ProvidersWrapper:
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
providers_with_models: List[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue

View File

@ -21,9 +21,9 @@ class ModelConfigEntity(BaseModel):
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {}
parameters: dict[str, Any] = {}
stop: list[str] = []
credentials: Dict[str, Any] = {}
parameters: Dict[str, Any] = {}
stop: List[str] = []
class AdvancedChatMessageEntity(BaseModel):
@ -40,7 +40,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
messages: List[AdvancedChatMessageEntity]
class AdvancedCompletionPromptTemplateEntity(BaseModel):
@ -102,7 +102,7 @@ class ExternalDataVariableEntity(BaseModel):
variable: str
type: str
config: dict[str, Any] = {}
config: Dict[str, Any] = {}
class DatasetRetrieveConfigEntity(BaseModel):
@ -146,7 +146,7 @@ class DatasetEntity(BaseModel):
Dataset Config Entity.
"""
dataset_ids: list[str]
dataset_ids: List[str]
retrieve_config: DatasetRetrieveConfigEntity
@ -156,7 +156,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
type: str
config: dict[str, Any] = {}
config: Dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
@ -185,7 +185,7 @@ class AgentToolEntity(BaseModel):
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}
tool_parameters: Dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
@ -234,7 +234,7 @@ class AgentEntity(BaseModel):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
tools: list[AgentToolEntity] = None
tools: List[AgentToolEntity] = None
max_iteration: int = 5
@ -245,7 +245,7 @@ class AppOrchestrationConfigEntity(BaseModel):
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
external_data_variables: list[ExternalDataVariableEntity] = []
external_data_variables: List[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
@ -319,13 +319,13 @@ class ApplicationGenerateEntity(BaseModel):
app_orchestration_config_entity: AppOrchestrationConfigEntity
conversation_id: Optional[str] = None
inputs: dict[str, str]
inputs: Dict[str, str]
query: Optional[str] = None
files: list[FileObj] = []
files: List[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
extras: Dict[str, Any] = {}

View File

@ -47,12 +47,12 @@ class ImagePromptMessageFile(PromptMessageFile):
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, list[Union[str, Dict]]]
# content: Union[str,List[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
files: List[PromptMessageFile]
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
def lc_messages_to_prompt_messages(messages: List[BaseMessage]) -> List[PromptMessage]:
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
@ -109,8 +109,8 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
def prompt_messages_to_lc_messages(
prompt_messages: list[PromptMessage],
) -> list[BaseMessage]:
prompt_messages: List[PromptMessage],
) -> List[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel
@ -31,7 +31,7 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
supported_model_types: List[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
"""
@ -66,7 +66,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
supported_model_types: List[ModelType]
class DefaultModelEntity(BaseModel):

View File

@ -1,9 +1,8 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
from typing import Dict, Iterator, List, Optional
from pydantic import BaseModel
@ -162,7 +161,7 @@ class ProviderConfiguration(BaseModel):
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
) -> List[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
@ -189,8 +188,8 @@ class ProviderConfiguration(BaseModel):
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_custom_provider_models(
self, model_types: list[ModelType], provider_instance: ModelProvider
) -> list[ModelWithProviderEntity]:
self, model_types: List[ModelType], provider_instance: ModelProvider
) -> List[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -266,7 +265,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
configurations: dict[str, ProviderConfiguration] = {}
configurations: Dict[str, ProviderConfiguration] = {}
def __init__(self):
super().__init__()
@ -276,7 +275,7 @@ class ProviderConfigurations(BaseModel):
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False,
) -> list[ModelWithProviderEntity]:
) -> List[ModelWithProviderEntity]:
"""
Get available models.
@ -317,7 +316,7 @@ class ProviderConfigurations(BaseModel):
return all_models
def to_list(self) -> list[ProviderConfiguration]:
def to_list(self) -> List[ProviderConfiguration]:
"""
Convert to list.

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel
@ -68,7 +68,7 @@ class QuotaConfiguration(BaseModel):
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
restrict_models: List[RestrictModel] = []
class SystemConfiguration(BaseModel):
@ -78,7 +78,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
quota_configurations: List[QuotaConfiguration] = []
credentials: Optional[dict] = None
@ -106,4 +106,4 @@ class CustomConfiguration(BaseModel):
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
models: List[CustomModelConfiguration] = []

View File

@ -68,7 +68,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
retriever_resources: List[dict]
class AnnotationReplyEvent(AppQueueEvent):

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from typing import IO, Generator, List, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.model_runtime.callbacks.base_callback import Callback
@ -68,13 +67,13 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -105,7 +104,7 @@ class ModelInstance:
)
def invoke_text_embedding(
self, texts: list[str], user: Optional[str] = None
self, texts: List[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke large language model
@ -125,7 +124,7 @@ class ModelInstance:
def invoke_rerank(
self,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,

View File

@ -69,9 +69,9 @@ Model Runtime is split into three layers:
Here we first need to distinguish model parameters from model credentials.
- Model parameters (**defined in this layer**): parameters that change often and are tuned at any time, such as an LLM's **max_tokens**, **temperature**, and so on. They are adjusted by users on the frontend page, so the backend must define the parameter rules for the frontend to display and adjust. In DifyRuntime their parameter name is usually **model_parameters: dict[str, any]**.
- Model parameters (**defined in this layer**): parameters that change often and are tuned at any time, such as an LLM's **max_tokens**, **temperature**, and so on. They are adjusted by users on the frontend page, so the backend must define the parameter rules for the frontend to display and adjust. In DifyRuntime their parameter name is usually **model_parameters: Dict[str, any]**.
- Model credentials (**defined in the provider layer**): parameters that rarely change and usually stay fixed once configured, such as **api_key**, **server_url**, and so on. In DifyRuntime their parameter name is usually **credentials: dict[str, any]**; the Provider layer's credentials are passed straight through to this layer and do not need to be defined again.
- Model credentials (**defined in the provider layer**): parameters that rarely change and usually stay fixed once configured, such as **api_key**, **server_url**, and so on. In DifyRuntime their parameter name is usually **credentials: Dict[str, any]**; the Provider layer's credentials are passed straight through to this layer and do not need to be defined again (see the sketch below for how both dicts reach a model).
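To make the distinction concrete, here is a minimal runnable sketch (not part of this commit) that mirrors the `_invoke` parameter split documented in these docs; the model name and field values are illustrative assumptions.
```python
from typing import List, Optional

# Sketch only: request-level model_parameters vs. provider-level credentials.
# The model name and field values are illustrative assumptions.

def invoke_example(model: str, credentials: dict, model_parameters: dict,
                   stop: Optional[List[str]] = None) -> None:
    # credentials: configured once in the Provider layer (api_key, server_url, ...)
    # model_parameters: adjusted per request from the frontend (temperature, max_tokens, ...)
    print(f"calling {model} with {model_parameters}, key ends with {credentials['api_key'][-4:]}")

invoke_example(
    model="glm-4",                                             # hypothetical model name
    credentials={"api_key": "sk-xxxx"},                        # provider-level credential
    model_parameters={"temperature": 0.7, "max_tokens": 256},  # request-level parameters
)
```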
## Next Steps

View File

@ -1,5 +1,5 @@
from abc import ABC
from typing import Optional
from typing import List, Optional
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@ -33,10 +33,10 @@ class Callback(ABC):
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -61,10 +61,10 @@ class Callback(ABC):
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@ -90,10 +90,10 @@ class Callback(ABC):
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -119,10 +119,10 @@ class Callback(ABC):
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

View File

@ -1,7 +1,7 @@
import json
import logging
import sys
from typing import Optional
from typing import List, Optional
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import (
@ -23,10 +23,10 @@ class LoggingCallback(Callback):
llm_instance: AIModel,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -79,10 +79,10 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@ -109,10 +109,10 @@ class LoggingCallback(Callback):
result: LLMResult,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@ -154,10 +154,10 @@ class LoggingCallback(Callback):
ex: Exception,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:

View File

@ -71,7 +71,7 @@ All models need to uniformly implement the following 2 methods:
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -95,8 +95,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
prompt_messages:List[PromptMessage], model_parameters: dict,
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -157,8 +157,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@ -196,7 +196,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
```python
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
texts:List[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
@ -230,7 +230,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
- Pre-calculating Tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts:List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -251,7 +251,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo
```python
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
query: str, docs:List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
@ -498,7 +498,7 @@ class PromptMessage(ABC, BaseModel):
Model class for prompt message.
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
content: Optional[str |List[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
name: Optional[str] = None
```
@ -539,7 +539,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction # tool call information
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
tool_calls:List[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
```
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
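For illustration, a single entry in `tool_calls` follows the OpenAI-style tool-call shape sketched below; the exact field names are assumptions to be checked against the entity definitions, and the tool name and arguments are made up.
```python
# Illustrative shape of one tool_calls entry (OpenAI-style); field names are assumptions.
tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "get_weather",                   # tool the model wants to invoke
        "arguments": '{"location": "Beijing"}',  # JSON-encoded arguments produced by the model
    },
}
```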
@ -593,7 +593,7 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
model: str # Actual model used
prompt_messages: list[PromptMessage] # prompt messages
prompt_messages:List[PromptMessage] # prompt messages
message: AssistantPromptMessage # response message
usage: LLMUsage # usage info
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
@ -624,7 +624,7 @@ class LLMResultChunk(BaseModel):
Model class for llm result chunk.
"""
model: str # Actual model used
prompt_messages: list[PromptMessage] # prompt messages
prompt_messages:List[PromptMessage] # prompt messages
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
delta: LLMResultChunkDelta
```
@ -660,7 +660,7 @@ class TextEmbeddingResult(BaseModel):
Model class for text embedding result.
"""
model: str # Actual model used
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
embeddings:List[List[float]] # List of embedding vectors, corresponding to the input texts list
usage: EmbeddingUsage # Usage information
```
@ -690,7 +690,7 @@ class RerankResult(BaseModel):
Model class for rerank result.
"""
model: str # Actual model used
docs: list[RerankDocument] # Reranked document list
docs:List[RerankDocument] # Reranked document list
```
### RerankDocument

View File

@ -160,8 +160,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
prompt_messages:List[PromptMessage], model_parameters: dict,
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -184,8 +184,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@ -226,7 +226,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -126,8 +126,8 @@ provider_credential_schema:
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
prompt_messages:List[PromptMessage], model_parameters: dict,
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -166,8 +166,8 @@ provider_credential_schema:
If the model does not provide an interface for pre-calculating tokens, it can simply return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@ -283,7 +283,7 @@ provider_credential_schema:
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -80,7 +80,7 @@ class XinferenceProvider(Provider):
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -95,7 +95,7 @@ class XinferenceProvider(Provider):
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
@ -127,8 +127,8 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
prompt_messages:List[PromptMessage], model_parameters: dict,
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -189,8 +189,8 @@ class XinferenceProvider(Provider):
If the model does not provide an interface for pre-calculating tokens, it can simply return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@ -232,7 +232,7 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
texts:List[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
@ -266,7 +266,7 @@ class XinferenceProvider(Provider):
- Pre-calculating tokens
```python
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts:List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -289,7 +289,7 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
query: str, docs:List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
@ -538,7 +538,7 @@ class PromptMessage(ABC, BaseModel):
Model class for prompt message.
"""
role: PromptMessageRole # message role
content: Optional[str | list[PromptMessageContent]] = None # supports two types, string and content list; the content list exists to support multimodal input, see PromptMessageContent for details
content: Optional[str |List[PromptMessageContent]] = None # supports two types, string and content list; the content list exists to support multimodal input, see PromptMessageContent for details
name: Optional[str] = None # name, optional
```
@ -579,7 +579,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction # tool call information
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = [] # tool-call results returned by the model (only when tools are passed in and the model decides a tool call is needed)
tool_calls:List[ToolCall] = [] # tool-call results returned by the model (only when tools are passed in and the model decides a tool call is needed)
```
Here, `tool_calls` is the list of `tool call`s returned by the model after the model is invoked with `tools` passed in.
@ -633,7 +633,7 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
model: str # model actually used
prompt_messages: list[PromptMessage] # list of prompt messages
prompt_messages:List[PromptMessage] # list of prompt messages
message: AssistantPromptMessage # reply message
usage: LLMUsage # token usage and cost information
system_fingerprint: Optional[str] = None # request fingerprint; see OpenAI's definition of this parameter
@ -664,7 +664,7 @@ class LLMResultChunk(BaseModel):
Model class for llm result chunk.
"""
model: str # model actually used
prompt_messages: list[PromptMessage] # list of prompt messages
prompt_messages:List[PromptMessage] # list of prompt messages
system_fingerprint: Optional[str] = None # request fingerprint; see OpenAI's definition of this parameter
delta: LLMResultChunkDelta # content that changes in each iteration
```
@ -700,7 +700,7 @@ class TextEmbeddingResult(BaseModel):
Model class for text embedding result.
"""
model: str # model actually used
embeddings: list[list[float]] # list of embedding vectors, corresponding to the input texts list
embeddings:List[List[float]] # list of embedding vectors, corresponding to the input texts list
usage: EmbeddingUsage # usage information
```
@ -730,7 +730,7 @@ class RerankResult(BaseModel):
Model class for rerank result.
"""
model: str # model actually used
docs: list[RerankDocument] # list of reranked segments
docs:List[RerankDocument] # list of reranked segments
```
### RerankDocument

View File

@ -76,8 +76,8 @@ pricing: # pricing information
```python
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
prompt_messages:List[PromptMessage], model_parameters: dict,
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@ -116,8 +116,8 @@ pricing: # pricing information
If the model does not provide an interface for pre-calculating tokens, it can simply return 0.
```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@ -158,7 +158,7 @@ pricing: # pricing information
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,8 +1,10 @@
from typing import Dict
from model_providers.core.model_runtime.entities.model_entities import (
DefaultParameterName,
)
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
"label": {
"en_US": "Temperature",

View File

@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel
@ -78,7 +78,7 @@ class LLMResult(BaseModel):
"""
model: str
prompt_messages: list[PromptMessage]
prompt_messages: List[PromptMessage]
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: Optional[str] = None
@ -101,7 +101,7 @@ class LLMResultChunk(BaseModel):
"""
model: str
prompt_messages: list[PromptMessage]
prompt_messages: List[PromptMessage]
system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta

View File

@ -1,6 +1,6 @@
from abc import ABC
from enum import Enum
from typing import Optional
from typing import List, Optional, Union
from pydantic import BaseModel
@ -110,7 +110,7 @@ class PromptMessage(ABC, BaseModel):
"""
role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None
content: Optional[Union[str, List[PromptMessageContent]]] = None
name: Optional[str] = None
@ -145,7 +145,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = []
tool_calls: List[ToolCall] = []
class SystemPromptMessage(PromptMessage):

View File

@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
@ -158,9 +158,9 @@ class ProviderModel(BaseModel):
model: str
label: I18nObject
model_type: ModelType
features: Optional[list[ModelFeature]] = None
features: Optional[List[ModelFeature]] = None
fetch_from: FetchFrom
model_properties: dict[ModelPropertyKey, Any]
model_properties: Dict[ModelPropertyKey, Any]
deprecated: bool = False
class Config:
@ -182,7 +182,7 @@ class ParameterRule(BaseModel):
min: Optional[float] = None
max: Optional[float] = None
precision: Optional[int] = None
options: list[str] = []
options: List[str] = []
class PriceConfig(BaseModel):
@ -201,7 +201,7 @@ class AIModelEntity(ProviderModel):
Model class for AI model.
"""
parameter_rules: list[ParameterRule] = []
parameter_rules: List[ParameterRule] = []
pricing: Optional[PriceConfig] = None

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel
@ -48,7 +48,7 @@ class FormOption(BaseModel):
label: I18nObject
value: str
show_on: list[FormShowOnObject] = []
show_on: List[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@ -66,10 +66,10 @@ class CredentialFormSchema(BaseModel):
type: FormType
required: bool = True
default: Optional[str] = None
options: Optional[list[FormOption]] = None
options: Optional[List[FormOption]] = None
placeholder: Optional[I18nObject] = None
max_length: int = 0
show_on: list[FormShowOnObject] = []
show_on: List[FormShowOnObject] = []
class ProviderCredentialSchema(BaseModel):
@ -77,7 +77,7 @@ class ProviderCredentialSchema(BaseModel):
Model class for provider credential schema.
"""
credential_form_schemas: list[CredentialFormSchema]
credential_form_schemas: List[CredentialFormSchema]
class FieldModelSchema(BaseModel):
@ -91,7 +91,7 @@ class ModelCredentialSchema(BaseModel):
"""
model: FieldModelSchema
credential_form_schemas: list[CredentialFormSchema]
credential_form_schemas: List[CredentialFormSchema]
class SimpleProviderEntity(BaseModel):
@ -103,8 +103,8 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
models: list[AIModelEntity] = []
supported_model_types: List[ModelType]
models: List[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
@ -128,9 +128,9 @@ class ProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = []
supported_model_types: List[ModelType]
configurate_methods: List[ConfigurateMethod]
models: List[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None

View File

@ -1,3 +1,5 @@
from typing import List
from pydantic import BaseModel
@ -17,4 +19,4 @@ class RerankResult(BaseModel):
"""
model: str
docs: list[RerankDocument]
docs: List[RerankDocument]

View File

@ -1,4 +1,5 @@
from decimal import Decimal
from typing import List
from pydantic import BaseModel
@ -25,5 +26,5 @@ class TextEmbeddingResult(BaseModel):
"""
model: str
embeddings: list[list[float]]
embeddings: List[List[float]]
usage: EmbeddingUsage

View File

@ -1,7 +1,7 @@
import decimal
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import Dict, List, Optional, Type
import yaml
@ -35,7 +35,7 @@ class AIModel(ABC):
"""
model_type: ModelType
model_schemas: list[AIModelEntity] = None
model_schemas: List[AIModelEntity] = None
started_at: float = 0
@abstractmethod
@ -51,7 +51,7 @@ class AIModel(ABC):
@property
@abstractmethod
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -133,7 +133,7 @@ class AIModel(ABC):
currency=price_config.currency,
)
def predefined_models(self) -> list[AIModelEntity]:
def predefined_models(self) -> List[AIModelEntity]:
"""
Get all predefined models for given provider.

View File

@ -3,8 +3,7 @@ import os
import re
import time
from abc import abstractmethod
from collections.abc import Generator
from typing import Optional, Union
from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.callbacks.logging_callback import (
@ -47,13 +46,13 @@ class LargeLanguageModel(AIModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@ -170,13 +169,13 @@ class LargeLanguageModel(AIModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
@ -290,7 +289,7 @@ if you are not sure about the structure.
def _code_block_mode_stream_processor(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None],
) -> Generator[LLMResultChunk, None, None]:
"""
@ -428,13 +427,13 @@ if you are not sure about the structure.
model: str,
result: Generator,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Generator:
"""
Invoke result generator
@ -498,10 +497,10 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -525,8 +524,8 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -539,7 +538,7 @@ if you are not sure about the structure.
"""
raise NotImplementedError
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
@ -575,7 +574,7 @@ if you are not sure about the structure.
index += 1
time.sleep(0.01)
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
def get_parameter_rules(self, model: str, credentials: dict) -> List[ParameterRule]:
"""
Get parameter rules
@ -658,13 +657,13 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> None:
"""
Trigger before invoke callbacks
@ -706,13 +705,13 @@ if you are not sure about the structure.
chunk: LLMResultChunk,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> None:
"""
Trigger new chunk callbacks
@ -755,13 +754,13 @@ if you are not sure about the structure.
model: str,
result: LLMResult,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> None:
"""
Trigger after invoke callbacks
@ -805,13 +804,13 @@ if you are not sure about the structure.
model: str,
ex: Exception,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> None:
"""
Trigger invoke error callbacks
@ -911,7 +910,7 @@ if you are not sure about the structure.
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.FLOAT:
if not isinstance(parameter_value, float | int):
if not isinstance(parameter_value, (float, int)):
raise ValueError(
f"Model Parameter {parameter_name} should be float."
)

View File

@ -1,6 +1,7 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import yaml
@ -14,7 +15,7 @@ from model_providers.core.model_runtime.model_providers.__base.ai_model import A
class ModelProvider(ABC):
provider_schema: ProviderEntity = None
model_instance_map: dict[str, AIModel] = {}
model_instance_map: Dict[str, AIModel] = {}
@abstractmethod
def validate_provider_credentials(self, credentials: dict) -> None:
@ -65,7 +66,7 @@ class ModelProvider(ABC):
return provider_schema
def models(self, model_type: ModelType) -> list[AIModelEntity]:
def models(self, model_type: ModelType) -> List[AIModelEntity]:
"""
Get all models for given model type

View File

@ -1,6 +1,6 @@
import time
from abc import abstractmethod
from typing import Optional
from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
@ -19,7 +19,7 @@ class RerankModel(AIModel):
model: str,
credentials: dict,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@ -51,7 +51,7 @@ class RerankModel(AIModel):
model: str,
credentials: dict,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import IO, Optional
from typing import IO, List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
@ -19,7 +19,7 @@ class Text2ImageModel(AIModel):
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
) -> list[IO[bytes]]:
) -> List[IO[bytes]]:
"""
Invoke Text2Image model
@ -44,7 +44,7 @@ class Text2ImageModel(AIModel):
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
) -> list[IO[bytes]]:
) -> List[IO[bytes]]:
"""
Invoke Text2Image model

View File

@ -1,6 +1,6 @@
import time
from abc import abstractmethod
from typing import Optional
from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import (
ModelPropertyKey,
@ -23,7 +23,7 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -62,7 +62,7 @@ class TextEmbeddingModel(AIModel):
raise NotImplementedError
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages

View File

@ -1,7 +1,6 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Dict, Generator, List, Optional, Type, Union, cast
import anthropic
import requests
@ -63,10 +62,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -92,9 +91,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -158,13 +157,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@ -203,12 +202,12 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@ -251,8 +250,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -297,7 +296,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Message,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm chat response
@ -345,7 +344,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
@ -424,8 +423,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs
def _convert_prompt_messages(
self, prompt_messages: list[PromptMessage]
) -> tuple[str, list[dict]]:
self, prompt_messages: List[PromptMessage]
) -> tuple[str, List[dict]]:
"""
Convert prompt messages to dict list and system
"""
@ -560,7 +559,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt_anthropic(
self, messages: list[PromptMessage]
self, messages: List[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@ -583,7 +582,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
import openai
from httpx import Timeout
@ -29,7 +31,7 @@ class _CommonAzureOpenAI:
return credentials_kwargs
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],

View File

@ -1,7 +1,6 @@
import copy
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Generator, List, Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
@ -60,10 +59,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -102,8 +101,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
model_mode = self._get_ai_model_entity(
credentials.get("base_model_name"), model
@ -176,9 +175,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -215,7 +214,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Completion,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
assistant_text = response.choices[0].text
@ -257,7 +256,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[Completion],
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
full_text = ""
for chunk in response:
@ -321,10 +320,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -381,8 +380,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> LLMResult:
assistant_message = response.choices[0].message
# assistant_message_tool_calls = assistant_message.tool_calls
@ -435,8 +434,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> Generator:
index = 0
full_assistant_content = ""
@ -545,8 +544,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _extract_response_tool_calls(
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
) -> list[AssistantPromptMessage.ToolCall]:
response_tool_calls: List[
Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
],
) -> List[AssistantPromptMessage.ToolCall]:
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
@ -566,7 +567,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _extract_response_function_call(
response_function_call: FunctionCall | ChoiceDeltaFunctionCall,
response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall],
) -> AssistantPromptMessage.ToolCall:
tool_call = None
if response_function_call:
@ -651,7 +652,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
credentials: dict,
text: str,
tools: Optional[list[PromptMessageTool]] = None,
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
@ -668,8 +669,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _num_tokens_from_messages(
self,
credentials: dict,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@ -743,7 +744,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _num_tokens_for_tools(
encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
) -> int:
num_tokens = 0
for tool in tools:
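
`_num_tokens_from_messages` in this file follows the familiar tiktoken counting heuristic described in its docstring. A hedged sketch of that heuristic; the per-message constants are the commonly published OpenAI-cookbook values and are assumed rather than copied from the hunks above:

```python
import tiktoken


def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
    """Approximate token count for chat messages, OpenAI-cookbook style."""
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message, tokens_per_name = 3, 1
    num_tokens = 0
    for message in messages:  # e.g. {"role": "user", "content": "ping"}
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens + 3  # every reply is primed with an assistant header
```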

View File

@ -1,7 +1,7 @@
import base64
import copy
import time
from typing import Optional, Union
from typing import List, Optional, Union
import numpy as np
import tiktoken
@ -35,7 +35,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
base_model_name = credentials["base_model_name"]
@ -51,7 +51,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@ -81,8 +81,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
results: List[List[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@ -112,7 +112,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
embeddings=embeddings, usage=usage, model=base_model_name
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
if len(texts) == 0:
return 0
@ -168,9 +168,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
def _embedding_invoke(
model: str,
client: AzureOpenAI,
texts: Union[list[str], str],
texts: Union[List[str], str],
extra_model_kwargs: dict,
) -> tuple[list[list[float]], int]:
) -> tuple[List[list[float]], int]:
response = client.embeddings.create(
input=texts,
model=model,
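
The `tokens` / `indices` / `results` / `num_tokens_in_batch` bookkeeping in the hunks above exists because inputs longer than the context size are cut into token chunks, embedded chunk by chunk, and then recombined into a single vector per input. A sketch of the usual recombination step, assuming numpy; the exact normalisation used by the real implementation is not shown in this diff:

```python
import numpy as np


def combine_chunk_embeddings(chunk_embeddings, chunk_token_counts):
    """Average per-chunk vectors, weighted by chunk length, then L2-normalise."""
    average = np.average(np.asarray(chunk_embeddings), axis=0, weights=chunk_token_counts)
    norm = np.linalg.norm(average)
    return (average / norm).tolist() if norm else average.tolist()
```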

View File

@ -1,8 +1,7 @@
from collections.abc import Generator
from enum import Enum
from hashlib import md5
from json import dumps, loads
from typing import Any, Union
from typing import Any, Dict, Generator, List, Union
from requests import post
@ -25,10 +24,10 @@ class BaichuanMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Dict[str, int] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@ -112,9 +111,9 @@ class BaichuanModel:
self,
model: str,
stream: bool,
messages: list[BaichuanMessage],
parameters: dict[str, Any],
) -> dict[str, Any]:
messages: List[BaichuanMessage],
parameters: Dict[str, Any],
) -> Dict[str, Any]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
@ -165,7 +164,7 @@ class BaichuanModel:
else:
raise BadRequestError(f"Unknown model: {model}")
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
@ -187,8 +186,8 @@ class BaichuanModel:
self,
model: str,
stream: bool,
messages: list[BaichuanMessage],
parameters: dict[str, Any],
messages: List[BaichuanMessage],
parameters: Dict[str, Any],
timeout: int,
) -> Union[Generator, BaichuanMessage]:
if (

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import cast
from typing import Dict, Generator, List, Type, Union, cast
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@ -49,13 +48,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@ -71,14 +70,14 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
messages: List[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
@ -149,13 +148,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
if tools is not None and len(tools) > 0:
raise InvokeBadRequestError("Baichuan model doesn't support tools")
@ -195,7 +194,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: BaichuanMessage,
) -> LLMResult:
@ -216,7 +215,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[BaichuanMessage, None, None],
) -> Generator:
@ -258,7 +257,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -47,7 +47,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -93,8 +93,8 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
return result
def embedding(
self, model: str, api_key, texts: list[str], user: Optional[str] = None
) -> tuple[list[list[float]], int]:
self, model: str, api_key, texts: List[str], user: Optional[str] = None
) -> tuple[List[list[float]], int]:
"""
Embed given texts
@ -154,7 +154,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
return [data["embedding"] for data in embeddings], usage["total_tokens"]
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -183,7 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],

View File

@ -1,7 +1,6 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
from typing import Dict, Generator, List, Optional, Type, Union
import boto3
from botocore.config import Config
@ -48,10 +47,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -77,8 +76,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
messages: Union[List[PromptMessage], str],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -99,7 +98,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(
self, model_prefix: str, messages: list[PromptMessage]
self, model_prefix: str, messages: List[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Google model
@ -190,7 +189,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str
self, messages: List[PromptMessage], model_prefix: str
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
@ -216,9 +215,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _create_payload(
self,
model_prefix: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
):
"""
@ -282,9 +281,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -356,7 +355,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -436,7 +435,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -551,7 +550,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -570,7 +569,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _map_client_to_invoke_error(
self, error_code: str, error_msg: str
) -> type[InvokeError]:
) -> Type[InvokeError]:
"""
Map client error to invoke error
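
`_map_client_to_invoke_error` translates botocore error codes into the unified `InvokeError` hierarchy. A sketch assuming a plain lookup table; the error-code strings and the chosen target classes are illustrative, and the stub exception classes stand in for the imported ones:

```python
class InvokeError(Exception): ...
class InvokeAuthorizationError(InvokeError): ...
class InvokeRateLimitError(InvokeError): ...
class InvokeConnectionError(InvokeError): ...
class InvokeBadRequestError(InvokeError): ...


def map_client_to_invoke_error(error_code: str, error_msg: str):
    mapping = {
        "AccessDeniedException": InvokeAuthorizationError,  # hypothetical examples
        "ThrottlingException": InvokeRateLimitError,
        "ModelTimeoutException": InvokeConnectionError,
        "ValidationException": InvokeBadRequestError,
    }
    return mapping.get(error_code, InvokeError)
```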

File diff suppressed because one or more lines are too long


View File

@ -1,9 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<mask id="mask0_8587_60212" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="1" y="2" width="23" height="21">
<path d="M23.8 2H1V22.4H23.8V2Z" fill="white"/>
</mask>
<g mask="url(#mask0_8587_60212)">
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.86378 14.4544C3.86378 13.0981 4.67438 11.737 6.25923 10.6634C7.83827 9.59364 10.0864 8.89368 12.6282 8.89368C15.17 8.89368 17.4182 9.59364 18.9972 10.6634C19.7966 11.2049 20.399 11.8196 20.7998 12.4699C21.2873 11.5802 21.4969 10.6351 21.3835 9.69252C21.3759 9.62928 21.3824 9.56766 21.4005 9.5106C21.0758 9.21852 20.7259 8.94624 20.3558 8.69556C18.3272 7.32126 15.5915 6.50964 12.6282 6.50964C9.66497 6.50964 6.92918 7.32126 4.90058 8.69556C2.8778 10.0659 1.45703 12.0812 1.45703 14.4544C1.45703 16.8275 2.8778 18.8428 4.90058 20.2132C6.92918 21.5875 9.66497 22.3991 12.6282 22.3991C15.5915 22.3991 18.3272 21.5875 20.3558 20.2132C22.3786 18.8428 23.7994 16.8275 23.7994 14.4544C23.7994 12.9455 23.225 11.5813 22.2868 10.4355C22.2377 11.4917 21.8621 12.5072 21.238 13.43C21.3409 13.7686 21.3926 14.1116 21.3926 14.4544C21.3926 15.8107 20.582 17.1717 18.9972 18.2453C17.4182 19.3151 15.17 20.015 12.6282 20.015C10.0864 20.015 7.83827 19.3151 6.25923 18.2453C4.67438 17.1717 3.86378 15.8107 3.86378 14.4544Z" fill="#3762FF"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.84445 11.6838C3.20239 13.4885 3.35368 15.1156 4.18868 16.2838C5.02368 17.452 6.52281 18.1339 8.45459 18.1334C10.3826 18.133 12.6296 17.44 14.6939 15.9922C16.7581 14.5444 18.1643 12.6753 18.8052 10.8739C19.4473 9.0692 19.2959 7.44206 18.461 6.27392C17.626 5.10572 16.1269 4.42389 14.1951 4.42431C12.267 4.42475 10.0201 5.11774 7.95575 6.56552C5.89152 8.01332 4.48529 9.8825 3.84445 11.6838ZM1.53559 10.8778C2.36374 8.55002 4.11254 6.28976 6.54117 4.58645C8.96981 2.88312 11.7029 1.99995 14.1945 1.99939C16.6825 1.99884 19.0426 2.8912 20.4589 4.87263C21.8752 6.85406 21.941 9.35564 21.1141 11.6799C20.2859 14.0077 18.5371 16.2679 16.1085 17.9713C13.6798 19.6746 10.9468 20.5578 8.45513 20.5584C5.9672 20.5589 3.60706 19.6665 2.19075 17.6851C0.774446 15.7036 0.708677 13.2021 1.53559 10.8778Z" fill="#1041F3"/>
</g>
</svg>


View File

@ -1,36 +0,0 @@
import logging
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
ModelProvider,
)
logger = logging.getLogger(__name__)
class ChatGLMProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
# Use `chatglm3-6b` model for validate,
model_instance.validate_credentials(
model="chatglm3-6b", credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(
f"{self.get_provider_schema().provider} credentials validate failed"
)
raise ex

View File

@ -1,28 +0,0 @@
provider: chatglm
label:
en_US: ChatGLM
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#F4F7FF"
help:
title:
en_US: Deploy ChatGLM to your local
zh_Hans: 部署您的本地 ChatGLM
url:
en_US: https://github.com/THUDM/ChatGLM3
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_base
label:
en_US: API URL
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 API URL
en_US: Enter your API URL

View File

@ -1,21 +0,0 @@
model: chatglm2-6b-32k
label:
en_US: ChatGLM2-6B-32K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 2000
min: 1
max: 32000

View File

@ -1,21 +0,0 @@
model: chatglm2-6b
label:
en_US: ChatGLM2-6B
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 2000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 256
min: 1
max: 2000

View File

@ -1,22 +0,0 @@
model: chatglm3-6b-32k
label:
en_US: ChatGLM3-6B-32K
model_type: llm
features:
- tool-call
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 32000

View File

@ -1,22 +0,0 @@
model: chatglm3-6b
label:
en_US: ChatGLM3-6B
model_type: llm
features:
- tool-call
- agent-thought
model_properties:
mode: chat
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 256
min: 1
max: 8000

View File

@ -1,555 +0,0 @@
import logging
from collections.abc import Generator
from os.path import join
from typing import Optional, cast
from httpx import Timeout
from openai import (
APIConnectionError,
APITimeoutError,
AuthenticationError,
ConflictError,
InternalServerError,
NotFoundError,
OpenAI,
PermissionDeniedError,
RateLimitError,
Stream,
UnprocessableEntityError,
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from model_providers.core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class ChatGLMLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
return self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
prompt_messages=[
UserPromptMessage(content="ping"),
],
model_parameters={
"max_tokens": 16,
},
)
except Exception as e:
raise CredentialsValidateFailedError(str(e))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
APIConnectionError,
APITimeoutError,
],
InvokeServerUnavailableError: [
InternalServerError,
ConflictError,
NotFoundError,
UnprocessableEntityError,
PermissionDeniedError,
],
InvokeRateLimitError: [RateLimitError],
InvokeAuthorizationError: [AuthenticationError],
InvokeBadRequestError: [ValueError],
}
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
self._check_chatglm_parameters(
model=model, model_parameters=model_parameters, tools=tools
)
kwargs = self._to_client_kwargs(credentials)
# init model client
client = OpenAI(**kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
if tools and len(tools) > 0:
extra_model_kwargs["functions"] = [
helper.dump_model(tool) for tool in tools
]
result = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
return self._handle_chat_generate_response(
model=model,
credentials=credentials,
response=result,
tools=tools,
prompt_messages=prompt_messages,
)
def _check_chatglm_parameters(
self, model: str, model_parameters: dict, tools: list[PromptMessageTool]
) -> None:
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict for OpenAI Compatibility API
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls and len(message.tool_calls) > 0:
message_dict["function_call"] = {
"name": message.tool_calls[0].function.name,
"arguments": message.tool_calls[0].function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_function_calls:
for response_tool_call in response_function_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.name, arguments=response_tool_call.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=0, type="function", function=function
)
tool_calls.append(tool_call)
return tool_calls
def _to_client_kwargs(self, credentials: dict) -> dict:
"""
Convert invoke kwargs to client kwargs
:param stream: is stream response
:param model_name: model name
:param credentials: credentials dict
:param model_parameters: model parameters
:return: client kwargs
"""
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": join(credentials["api_base"], "v1"),
}
return client_kwargs
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator:
full_response = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (
delta.delta.content is None or delta.delta.content == ""
):
continue
# check if there is a tool call in the response
function_calls = None
if delta.delta.function_call:
function_calls = [delta.delta.function_call]
assistant_message_tool_calls = self._extract_response_tool_calls(
function_calls if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else "",
tool_calls=assistant_message_tool_calls,
)
if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
temp_assistant_prompt_message = AssistantPromptMessage(
content=full_response, tool_calls=assistant_message_tool_calls
)
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[temp_assistant_prompt_message], tools=[]
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
),
)
full_response += delta.delta.content
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: llm response
"""
if len(response.choices) == 0:
raise InvokeServerUnavailableError("Empty response")
assistant_message = response.choices[0].message
# convert function call to tool call
function_calls = assistant_message.function_call
tool_calls = self._extract_response_tool_calls(
[function_calls] if function_calls else []
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_message.content, tool_calls=tool_calls
)
prompt_tokens = self._num_tokens_from_messages(
messages=prompt_messages, tools=tools
)
completion_tokens = self._num_tokens_from_messages(
messages=[assistant_prompt_message], tools=tools
)
usage = self._calc_response_usage(
model=model,
credentials=credentials,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=response.system_fingerprint,
usage=usage,
message=assistant_prompt_message,
)
return response
def _num_tokens_from_string(
self, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
:param model: model name
:param text: prompt text
:param tools: tools for tool calling
:return: number of tokens
"""
num_tokens = self._get_num_tokens_by_gpt2(text)
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
As a temporary solution we use GPT2 tokenizer instead.
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "function_call":
for t_key, t_value in value.items():
num_tokens += tokens(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += tokens(f_key)
num_tokens += tokens(f_value)
else:
num_tokens += tokens(t_key)
num_tokens += tokens(t_value)
else:
num_tokens += tokens(str(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
:param encoding: encoding
:param tools: tools for tool calling
:return: number of tokens
"""
def tokens(text: str):
return self._get_num_tokens_by_gpt2(text)
num_tokens = 0
for tool in tools:
# calculate num tokens for function object
num_tokens += tokens("name")
num_tokens += tokens(tool.name)
num_tokens += tokens("description")
num_tokens += tokens(tool.description)
parameters = tool.parameters
num_tokens += tokens("parameters")
num_tokens += tokens("type")
num_tokens += tokens(parameters.get("type"))
if "properties" in parameters:
num_tokens += tokens("properties")
for key, value in parameters.get("properties").items():
num_tokens += tokens(key)
for field_key, field_value in value.items():
num_tokens += tokens(field_key)
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += tokens(enum_field)
else:
num_tokens += tokens(field_key)
num_tokens += tokens(str(field_value))
if "required" in parameters:
num_tokens += tokens("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += tokens(required_field)
return num_tokens

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Dict, List, Optional, Type, Union, cast
import cohere
from cohere.responses import Chat, Generations
@ -55,10 +55,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -103,8 +103,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -171,9 +171,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -216,7 +216,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Generations,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -256,7 +256,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: StreamingGenerations,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -317,9 +317,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -377,8 +377,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Chat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
prompt_messages: List[PromptMessage],
stop: Optional[List[str]] = None,
) -> LLMResult:
"""
Handle llm chat response
@ -429,8 +429,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: StreamingChat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
prompt_messages: List[PromptMessage],
stop: Optional[List[str]] = None,
) -> Generator:
"""
Handle llm chat stream response
@ -517,8 +517,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(
self, prompt_messages: list[PromptMessage]
) -> tuple[str, list[dict]]:
self, prompt_messages: List[PromptMessage]
) -> tuple[str, List[dict]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
@ -586,7 +586,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response.length
def _num_tokens_from_messages(
self, model: str, credentials: dict, messages: list[PromptMessage]
self, model: str, credentials: dict, messages: List[PromptMessage]
) -> int:
"""Calculate num tokens Cohere model."""
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
@ -650,7 +650,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
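
`_convert_prompt_messages_to_message_and_chat_histories` splits the conversation into the latest user message plus a history list, which is how Cohere's chat endpoint expects its input. A hedged sketch using plain dicts in place of `PromptMessage`; the `USER`/`CHATBOT` role strings follow Cohere's public chat API convention and are an assumption here, not taken from this diff:

```python
def to_message_and_chat_histories(prompt_messages):
    """Split chat messages into (latest message, Cohere-style chat history)."""
    *history, last = prompt_messages  # e.g. [{"role": "user", "content": "hi"}, ...]
    chat_histories = [
        {"role": "USER" if m["role"] == "user" else "CHATBOT", "message": m["content"]}
        for m in history
    ]
    return last["content"], chat_histories
```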

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, List, Optional, Type
import cohere
@ -32,7 +32,7 @@ class CohereRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@ -99,7 +99,7 @@ class CohereRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
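
The rerank `_invoke` signature above takes `score_threshold` and `top_n` alongside the query and documents; both are post-filters applied to the scores returned by the model. A sketch of that post-processing; the `relevance_score` key mirrors Cohere's response shape and is assumed here:

```python
def filter_rerank_results(scored_docs, score_threshold=None, top_n=None):
    """Apply score_threshold / top_n to raw rerank results."""
    results = sorted(scored_docs, key=lambda d: d["relevance_score"], reverse=True)
    if score_threshold is not None:
        results = [d for d in results if d["relevance_score"] >= score_threshold]
    return results[:top_n] if top_n is not None else results
```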

View File

@ -35,7 +35,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -51,7 +51,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@ -79,8 +79,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
results: List[List[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@ -105,7 +105,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -161,8 +161,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(
self, model: str, credentials: dict, texts: list[str]
) -> tuple[list[list[float]], int]:
self, model: str, credentials: dict, texts: List[str]
) -> tuple[List[list[float]], int]:
"""
Invoke embedding model
@ -216,7 +216,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,7 +1,6 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union
from typing import Dict, Generator, List, Optional, Type, Union
import google.api_core.exceptions as exceptions
import google.generativeai as genai
@ -66,10 +65,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -102,8 +101,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -118,7 +117,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
@ -155,10 +154,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -249,7 +248,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -306,7 +305,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -416,7 +415,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return glm_content
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -472,8 +471,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
tool_call = None
if response_function_call:
from google.protobuf import json_format
if isinstance(response_function_call, FunctionCall):
map_composite_dict = dict(response_function_call.args.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@ -16,10 +15,10 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from model_providers.core.model_runtime.errors.invoke import (
@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonHuggingfaceHub:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import Generator, List, Optional, Union
from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
@ -44,10 +43,10 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -81,8 +80,8 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
@ -161,7 +160,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
return entity
@staticmethod
def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
def _get_customizable_model_parameter_rules() -> List[ParameterRule]:
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
DefaultParameterName.TEMPERATURE
).copy()
@ -253,7 +252,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
response: Generator,
) -> Generator:
index = -1
@ -300,7 +299,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
response: any,
) -> LLMResult:
if isinstance(response, str):
@ -355,7 +354,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
return model_info.pipeline_tag
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list
text = "".join(

View File

@ -1,6 +1,6 @@
import json
import time
from typing import Optional
from typing import List, Optional
import numpy as np
import requests
@ -35,7 +35,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
@ -62,7 +62,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
embeddings=self._mean_pooling(embeddings), usage=usage, model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
@ -132,12 +132,12 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
return entity
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
# Returned values are a list of floats, or a list[list[floats]]
# Returned values are a list of floats, or a List[List[floats]]
# (depending on if you sent a string or a list of string,
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
# This should be explained on the model's README.)
@staticmethod
def _mean_pooling(embeddings: list) -> list[float]:
def _mean_pooling(embeddings: list) -> List[float]:
# If automatic reduction by giving model, no need to mean_pooling.
# For example one: List[List[float]]
if not isinstance(embeddings[0][0], list):
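
As the comment above notes, the feature-extraction endpoint may return one vector per token (`List[List[float]]`) rather than a single sentence vector, in which case `_mean_pooling` collapses the token vectors by averaging. A minimal sketch, assuming numpy:

```python
import numpy as np


def mean_pool(token_embeddings):
    """Collapse per-token vectors (List[List[float]]) into one sentence embedding."""
    return np.mean(np.asarray(token_embeddings), axis=0).tolist()
```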

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, List, Optional, Type
import httpx
@ -32,7 +32,7 @@ class JinaRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@ -108,7 +108,7 @@ class JinaRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
"""

View File

@ -1,6 +1,6 @@
import time
from json import JSONDecodeError, dumps
from typing import Optional
from typing import Dict, List, Optional, Type
from requests import post
@ -34,7 +34,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
api_base: str = "https://api.jina.ai/v1/embeddings"
models: list[str] = [
models: List[str] = [
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-small-en",
"jina-embeddings-v2-base-zh",
@ -45,7 +45,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -113,7 +113,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -142,7 +142,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import cast
from typing import Dict, Generator, List, Type, Union, cast
from httpx import Timeout
from openai import (
@ -64,13 +63,13 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@ -86,14 +85,14 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages, tools=tools)
def _num_tokens_from_messages(
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for baichuan model
@ -156,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@ -224,7 +223,7 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
) -> Union[AIModelEntity, None]:
completion_model = None
if credentials["completion_type"] == "chat_completion":
completion_model = LLMMode.CHAT.value
@ -286,13 +285,13 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
kwargs = self._to_client_kwargs(credentials)
# init model client
client = OpenAI(**kwargs)
@ -414,7 +413,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return message_dict
def _convert_prompt_message_to_completion_prompts(
self, messages: list[PromptMessage]
self, messages: List[PromptMessage]
) -> str:
"""
Convert PromptMessage to completion prompts
@ -438,7 +437,7 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_completion_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
@ -489,10 +488,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: ChatCompletion,
tools: list[PromptMessageTool],
tools: List[PromptMessageTool],
) -> LLMResult:
"""
Handle llm chat response
@ -547,10 +546,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_completion_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Stream[Completion],
tools: list[PromptMessageTool],
tools: List[PromptMessageTool],
) -> Generator:
full_response = ""
@ -613,10 +612,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
tools: list[PromptMessageTool],
tools: List[PromptMessageTool],
) -> Generator:
full_response = ""
@ -691,8 +690,8 @@ class LocalAILarguageModel(LargeLanguageModel):
full_response += delta.delta.content
def _extract_response_tool_calls(
self, response_function_calls: list[FunctionCall]
) -> list[AssistantPromptMessage.ToolCall]:
self, response_function_calls: List[FunctionCall]
) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -714,7 +713,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return tool_calls
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
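
For LocalAI models configured with a `completion_type` other than `chat_completion`, `_convert_prompt_message_to_completion_prompts` flattens the chat messages into a single prompt string. A hedged sketch; the role prefixes below are a common convention and an assumption, since the actual formatting is not visible in this diff:

```python
def to_completion_prompt(messages):
    """Flatten chat messages into one completion-mode prompt string."""
    prefixes = {"user": "Human:", "assistant": "Assistant:", "system": ""}
    lines = [f"{prefixes.get(m['role'], m['role'] + ':')} {m['content']}".strip() for m in messages]
    return "\n".join(lines) + "\nAssistant: "  # prime the model to answer as the assistant
```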

View File

@ -1,6 +1,6 @@
import time
from json import JSONDecodeError, dumps
from typing import Optional
from typing import Dict, List, Optional, Type, Union
from requests import post
from yarl import URL
@ -42,7 +42,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -121,7 +121,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -138,7 +138,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
def _get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
) -> Union[AIModelEntity, None]:
"""
Get customizable model schema
@ -177,7 +177,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],

View File

@ -1,6 +1,5 @@
from collections.abc import Generator
from json import dumps, loads
from typing import Any, Union
from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@ -27,10 +26,10 @@ class MinimaxChatCompletion:
model: str,
api_key: str,
group_id: str,
prompt_messages: list[MinimaxMessage],
prompt_messages: List[MinimaxMessage],
model_parameters: dict,
tools: list[dict[str, Any]],
stop: list[str] | None,
tools: List[Dict[str, Any]],
stop: Union[List[str], None],
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:

View File

@ -1,6 +1,5 @@
from collections.abc import Generator
from json import dumps, loads
from typing import Any, Union
from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@ -28,10 +27,10 @@ class MinimaxChatCompletionPro:
model: str,
api_key: str,
group_id: str,
prompt_messages: list[MinimaxMessage],
prompt_messages: List[MinimaxMessage],
model_parameters: dict,
tools: list[dict[str, Any]],
stop: list[str] | None,
tools: List[Dict[str, Any]],
stop: Union[List[str], None],
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:

View File

@ -58,13 +58,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
return self._generate(
model,
credentials,
@ -110,13 +110,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for minimax model
@ -137,13 +137,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
"""
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
"""
@ -227,7 +227,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: MinimaxMessage,
) -> LLMResult:
@ -250,7 +250,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[MinimaxMessage, None, None],
) -> Generator[LLMResultChunk, None, None]:
@ -319,7 +319,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -11,11 +11,11 @@ class MinimaxMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Dict[str, int] = None
stop_reason: str = ""
function_call: dict[str, Any] = None
function_call: Dict[str, Any] = None
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
"sender_type": "BOT",

View File

@ -1,6 +1,6 @@
import time
from json import dumps
from typing import Optional
from typing import Dict, List, Optional, Type
from requests import post
@ -44,7 +44,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -103,7 +103,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -146,7 +146,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
raise InternalServerError(msg)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@ -16,10 +15,10 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:

View File

@ -1,7 +1,7 @@
import importlib
import logging
import os
from typing import Optional, Union
from typing import Dict, List, Optional, Union
from pydantic import BaseModel
@ -42,7 +42,7 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory:
# init cache provider by default
init_cache: bool = False
model_provider_extensions: dict[str, ModelProviderExtension] = None
model_provider_extensions: Dict[str, ModelProviderExtension] = None
def __init__(self, init_cache: bool = False) -> None:
# for cache in memory
@ -51,7 +51,7 @@ class ModelProviderFactory:
def get_providers(
self, provider_name: Union[str, set] = ""
) -> list[ProviderEntity]:
) -> List[ProviderEntity]:
"""
Get all providers
:return: list of providers
@ -159,8 +159,8 @@ class ModelProviderFactory:
self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None,
) -> list[SimpleProviderEntity]:
provider_configs: Optional[List[ProviderConfig]] = None,
) -> List[SimpleProviderEntity]:
"""
Get all models for given model type
@ -234,7 +234,7 @@ class ModelProviderFactory:
return model_provider_instance
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
def _get_model_provider_map(self) -> Dict[str, ModelProviderExtension]:
if self.model_provider_extensions:
return self.model_provider_extensions
@ -254,7 +254,7 @@ class ModelProviderFactory:
position_map = get_position_map(model_providers_path)
# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []
model_providers: List[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path)

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@ -16,10 +16,10 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:

View File

@ -3,7 +3,7 @@ import logging
import re
from collections.abc import Generator
from decimal import Decimal
from typing import Optional, Union, cast
from typing import Dict, List, Optional, Type, Union, cast
from urllib.parse import urljoin
import requests
@ -63,10 +63,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -97,8 +97,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -159,9 +159,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -258,7 +258,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
completion_type: LLMMode,
response: requests.Response,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm completion response
@ -310,7 +310,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
completion_type: LLMMode,
response: requests.Response,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm completion stream response
@ -462,7 +462,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return message_dict
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
"""
Calculate num tokens.
@ -700,7 +700,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -2,7 +2,7 @@ import json
import logging
import time
from decimal import Decimal
from typing import Optional
from typing import Dict, List, Optional, Type
from urllib.parse import urljoin
import numpy as np
@ -48,7 +48,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -123,7 +123,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
embeddings=batched_embeddings, usage=usage, model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer
@ -211,7 +211,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
import openai
from httpx import Timeout
@ -35,7 +37,7 @@ class _CommonOpenAI:
return credentials_kwargs
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import List, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream
@ -72,10 +72,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -128,13 +128,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@ -194,12 +194,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@ -242,12 +242,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@ -287,8 +287,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -366,7 +366,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def remote_models(self, credentials: dict) -> list[AIModelEntity]:
def remote_models(self, credentials: dict) -> List[AIModelEntity]:
"""
Return remote models if credentials are provided.
@ -424,9 +424,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -479,7 +479,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Completion,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm completion response
@ -528,7 +528,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[Completion],
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm completion stream response
@ -599,10 +599,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -676,8 +676,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> LLMResult:
"""
Handle llm chat response
@ -740,8 +740,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> Generator:
"""
Handle llm chat stream response
@ -851,8 +851,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
def _extract_response_tool_calls(
self,
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
) -> list[AssistantPromptMessage.ToolCall]:
response_tool_calls: List[
Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
],
) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -877,7 +879,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return tool_calls
def _extract_response_function_call(
self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
@ -967,7 +969,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(
self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
@ -992,8 +994,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
def _num_tokens_from_messages(
self,
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@ -1068,7 +1070,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_for_tools(
self, encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
self, encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from openai import OpenAI
from openai.types import ModerationCreateResponse
@ -82,7 +82,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
raise CredentialsValidateFailedError(str(ex))
def _moderation_invoke(
self, model: str, client: OpenAI, texts: list[str]
self, model: str, client: OpenAI, texts: List[str]
) -> ModerationCreateResponse:
"""
Invoke moderation model

View File

@ -1,6 +1,6 @@
import base64
import time
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import tiktoken
@ -31,7 +31,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -58,7 +58,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@ -89,8 +89,8 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@ -118,7 +118,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -167,9 +167,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
self,
model: str,
client: OpenAI,
texts: Union[list[str], str],
texts: Union[List[str], str],
extra_model_kwargs: dict,
) -> tuple[list[list[float]], int]:
) -> Tuple[List[List[float]], int]:
"""
Invoke embedding model

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
import requests
from model_providers.core.model_runtime.errors.invoke import (
@ -12,7 +14,7 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonOAI_API_Compat:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Generator
from decimal import Decimal
from typing import Optional, Union, cast
from typing import List, Optional, Union, cast
from urllib.parse import urljoin
import requests
@ -61,10 +61,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -98,8 +98,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -282,10 +282,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -384,7 +384,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model: str,
credentials: dict,
response: requests.Response,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -516,7 +516,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model: str,
credentials: dict,
response: requests.Response,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
response_json = response.json()
@ -649,7 +649,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(
self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
) -> int:
"""
Approximate num tokens for model with gpt2 tokenizer.
@ -669,8 +669,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
def _num_tokens_from_messages(
self,
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
@ -722,7 +722,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.
@ -769,8 +769,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return num_tokens
def _extract_response_tool_calls(
self, response_tool_calls: list[dict]
) -> list[AssistantPromptMessage.ToolCall]:
self, response_tool_calls: List[dict]
) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response

View File

@ -1,7 +1,7 @@
import json
import time
from decimal import Decimal
from typing import Optional
from typing import List, Optional
from urllib.parse import urljoin
import numpy as np
@ -40,7 +40,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -131,7 +131,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
embeddings=batched_embeddings, usage=usage, model=model
)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer

View File

@ -1,4 +1,4 @@
from collections.abc import Generator
from typing import Dict, Generator, List, Type, Union
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.llm_entities import (
@ -54,13 +54,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
return self._generate(
model,
credentials,
@ -105,13 +105,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for OpenLLM model
@ -124,13 +124,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
client = OpenLLMGenerate()
response = client.generate(
model_name=model,
@ -183,7 +183,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: OpenLLMGenerateMessage,
) -> LLMResult:
@ -206,7 +206,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[OpenLLMGenerateMessage, None, None],
) -> Generator[LLMResultChunk, None, None]:
@ -249,7 +249,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
@ -298,7 +298,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -20,10 +20,10 @@ class OpenLLMGenerateMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Dict[str, int] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@ -40,9 +40,9 @@ class OpenLLMGenerate:
server_url: str,
model_name: str,
stream: bool,
model_parameters: dict[str, Any],
stop: list[str],
prompt_messages: list[OpenLLMGenerateMessage],
model_parameters: Dict[str, Any],
stop: List[str],
prompt_messages: List[OpenLLMGenerateMessage],
user: str,
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
if not server_url:

View File

@ -1,6 +1,6 @@
import time
from json import dumps
from typing import Optional
from typing import Dict, List, Optional, Type
from requests import post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@ -35,7 +35,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@ -89,7 +89,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@ -118,7 +118,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid server_url")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
from replicate.exceptions import ModelError, ReplicateError
from model_providers.core.model_runtime.errors.invoke import (
@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonReplicate:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {InvokeBadRequestError: [ReplicateError, ModelError]}

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import Generator, Optional, Union
from replicate import Client as ReplicateClient
from replicate.exceptions import ReplicateError
@ -43,10 +42,10 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -86,8 +85,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
@ -167,7 +166,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
@classmethod
def _get_customizable_model_parameter_rules(
cls, model: str, credentials: dict
) -> list[ParameterRule]:
) -> List[ParameterRule]:
version = credentials["model_version"]
client = ReplicateClient(
@ -215,8 +214,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
model: str,
credentials: dict,
prediction: Prediction,
stop: list[str],
prompt_messages: list[PromptMessage],
stop: List[str],
prompt_messages: List[PromptMessage],
) -> Generator:
index = -1
current_completion: str = ""
@ -281,8 +280,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
model: str,
credentials: dict,
prediction: Prediction,
stop: list[str],
prompt_messages: list[PromptMessage],
stop: List[str],
prompt_messages: List[PromptMessage],
) -> LLMResult:
current_completion: str = ""
stop_condition_reached = False
@ -332,7 +331,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
elif param_type == "string":
return "string"
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list
text = "".join(

View File

@ -1,6 +1,6 @@
import json
import time
from typing import Optional
from typing import List, Optional
from replicate import Client as ReplicateClient
@ -31,7 +31,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
self,
model: str,
credentials: dict,
texts: list[str],
texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
client = ReplicateClient(
@ -52,7 +52,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
@ -124,8 +124,8 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
client: ReplicateClient,
replicate_model_version: str,
text_input_key: str,
texts: list[str],
) -> list[list[float]]:
texts: List[str],
) -> List[List[float]]:
if text_input_key in ("text", "inputs"):
embeddings = []
for text in texts:

View File

@ -1,6 +1,6 @@
import threading
from collections.abc import Generator
from typing import Optional, Union
from typing import Dict, List, Optional, Type, Union
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@ -37,10 +37,10 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -66,8 +66,8 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -109,9 +109,9 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -171,13 +171,12 @@ class SparkLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
client: SparkLLMClient,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
:param model: model name
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
@ -222,7 +221,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
client: SparkLLMClient,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -300,7 +299,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@ -317,7 +316,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@ -21,10 +21,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -50,10 +50,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -81,8 +81,8 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

View File

@ -1,3 +1,5 @@
from typing import Dict, List, Type
from model_providers.core.model_runtime.errors.invoke import InvokeError
@ -11,7 +13,7 @@ class _CommonTongyi:
return credentials_kwargs
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Tongyi
@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult
class EnhanceTongyi(Tongyi):
@property
def _default_params(self) -> dict[str, Any]:
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params = {"top_p": self.top_p, "api_key": self.dashscope_api_key}
@ -16,13 +16,13 @@ class EnhanceTongyi(Tongyi):
def _generate(
self,
prompts: list[str],
stop: Optional[list[str]] = None,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = []
params: dict[str, Any] = {
params: Dict[str, Any] = {
**{"model": self.model_name},
**self._default_params,
**kwargs,

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, Union
from typing import Dict, List, Optional, Type, Union
from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
@ -50,10 +50,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -79,14 +79,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Wrapper for code block mode
"""
@ -174,8 +174,8 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
prompt_messages: List[PromptMessage],
tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@ -220,9 +220,9 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@ -289,7 +289,7 @@ if you are not sure about the structure.
model: str,
credentials: dict,
response: DashScopeAPIResponse,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@ -326,7 +326,7 @@ if you are not sure about the structure.
model: str,
credentials: dict,
responses: Generator,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@ -412,7 +412,7 @@ if you are not sure about the structure.
return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@ -429,8 +429,8 @@ if you are not sure about the structure.
return text.rstrip()
def _convert_prompt_messages_to_tongyi_messages(
self, prompt_messages: list[PromptMessage]
) -> list[dict]:
self, prompt_messages: List[PromptMessage]
) -> List[dict]:
"""
Convert prompt messages to tongyi messages
@ -466,7 +466,7 @@ if you are not sure about the structure.
return tongyi_messages
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -4,7 +4,7 @@ from io import BytesIO
from typing import Optional
import dashscope
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
from pydub import AudioSegment
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError

View File

@ -1,9 +1,8 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union
from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@ -19,7 +18,7 @@ from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_err
)
# map api_key to access_token
baidu_access_tokens: dict[str, "BaiduAccessToken"] = {}
baidu_access_tokens: Dict[str, "BaiduAccessToken"] = {}
baidu_access_tokens_lock = Lock()
@ -118,10 +117,10 @@ class ErnieMessage:
role: str = Role.USER.value
content: str
usage: dict[str, int] = None
usage: Dict[str, int] = None
stop_reason: str = ""
def to_dict(self) -> dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@ -156,11 +155,11 @@ class ErnieBotModel:
self,
model: str,
stream: bool,
messages: list[ErnieMessage],
parameters: dict[str, Any],
messages: List[ErnieMessage],
parameters: Dict[str, Any],
timeout: int,
tools: list[PromptMessageTool],
stop: list[str],
tools: List[PromptMessageTool],
stop: List[str],
user: str,
) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
# check parameters
@ -243,15 +242,15 @@ class ErnieBotModel:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]
def _check_parameters(
self,
model: str,
parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
parameters: Dict[str, Any],
tools: List[PromptMessageTool],
stop: List[str],
) -> None:
if model not in self.api_bases:
raise BadRequestError(f"Invalid model: {model}")
@ -276,13 +275,13 @@ class ErnieBotModel:
def _build_request_body(
self,
model: str,
messages: list[ErnieMessage],
messages: List[ErnieMessage],
stream: bool,
parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
parameters: Dict[str, Any],
tools: List[PromptMessageTool],
stop: List[str],
user: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
# if model in self.function_calling_supports:
# return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
return self._build_chat_request_body(
@ -292,13 +291,13 @@ class ErnieBotModel:
def _build_function_calling_request_body(
self,
model: str,
messages: list[ErnieMessage],
messages: List[ErnieMessage],
stream: bool,
parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
parameters: Dict[str, Any],
tools: List[PromptMessageTool],
stop: List[str],
user: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
if len(messages) % 2 == 0:
raise BadRequestError("The number of messages should be odd.")
if messages[0].role == "function":
@ -311,12 +310,12 @@ class ErnieBotModel:
def _build_chat_request_body(
self,
model: str,
messages: list[ErnieMessage],
messages: List[ErnieMessage],
stream: bool,
parameters: dict[str, Any],
stop: list[str],
parameters: Dict[str, Any],
stop: List[str],
user: str,
) -> dict[str, Any]:
) -> Dict[str, Any]:
if len(messages) == 0:
raise BadRequestError("The number of messages should not be zero.")

View File

@ -1,5 +1,4 @@
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Dict, Generator, List, Optional, Type, Union, cast
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import (
@ -59,13 +58,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@ -81,13 +80,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[List[PromptMessageTool]] = None,
stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@ -140,12 +139,12 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@ -187,15 +186,15 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
messages: List[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
@ -234,13 +233,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
instance = ErnieBotModel(
api_key=credentials["api_key"],
secret_key=credentials["secret_key"],
@ -304,7 +303,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: ErnieMessage,
) -> LLMResult:
@ -325,7 +324,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[ErnieMessage, None, None],
) -> Generator:
@ -367,7 +366,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Iterator
from typing import cast
from typing import Dict, List, Union, cast
from openai import (
APIConnectionError,
@ -81,13 +81,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
"""
invoke LLM
@ -168,8 +168,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
prompt_messages: List[PromptMessage],
tools: Union[List[PromptMessageTool], None] = None,
) -> int:
"""
get number of tokens
@ -181,8 +181,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _num_tokens_from_messages(
self,
messages: list[PromptMessage],
tools: list[PromptMessageTool],
messages: List[PromptMessage],
tools: List[PromptMessageTool],
is_completion_model: bool = False,
) -> int:
def tokens(text: str):
@ -240,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@ -284,7 +284,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return num_tokens
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
def _convert_prompt_message_to_text(self, message: List[PromptMessage]) -> str:
"""
convert prompt message to text
"""
@ -337,7 +337,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
@ -412,14 +412,14 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: List[PromptMessage],
model_parameters: dict,
extra_model_kwargs: XinferenceModelExtraParameter,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
tools: Union[List[PromptMessageTool], None] = None,
stop: Union[List[str], None] = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
user: Union[str, None] = None,
) -> Union[LLMResult, Generator]:
"""
generate text from LLM
@ -525,8 +525,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls(
self,
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
) -> list[AssistantPromptMessage.ToolCall]:
response_tool_calls: List[
Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
],
) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@ -551,7 +553,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return tool_calls
def _extract_response_function_call(
self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
@ -576,8 +578,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
resp: ChatCompletion,
) -> LLMResult:
"""
@ -633,8 +635,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
resp: Iterator[ChatCompletionChunk],
) -> Generator:
"""
@ -721,8 +723,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
resp: Completion,
) -> LLMResult:
"""
@ -765,8 +767,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
resp: Iterator[Completion],
) -> Generator:
"""
@ -834,7 +836,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
full_response += delta.text
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, List, Optional, Type, Union
from xinference_client.client.restful.restful_client import (
Client,
@ -41,7 +41,7 @@ class XinferenceRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
docs: list[str],
docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@ -133,7 +133,7 @@ class XinferenceRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -152,7 +152,7 @@ class XinferenceRerankModel(RerankModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> AIModelEntity | None:
) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""

Some files were not shown because too many files have changed in this diff
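
The change repeated across every file above is the same Python 3.8 compatibility rewrite: built-in generic annotations (`list[str]`, `dict[str, int]`, `type[X]`) and PEP 604 unions (`LLMResult | Generator`, `str | None`) are only valid on Python 3.9 / 3.10 and later, so they are swapped for the `typing` equivalents (`List`, `Dict`, `Type`, `Union`, `Optional`). A minimal sketch of the before/after pattern, using illustrative names rather than the repository's real signatures:

```python
# Python 3.8-compatible annotations: typing.List/Dict/Type/Union instead of
# list[...], dict[...], type[...] and "X | Y" (which need Python 3.9 / 3.10+).
from typing import Dict, List, Type, Union


# On Python 3.8 this would fail at definition time if written as:
#     def invoke(prompt: str, stop: list[str] | None = None) -> dict[str, int]:
def invoke(prompt: str, stop: Union[List[str], None] = None) -> Dict[str, int]:
    """Toy stand-in for the signatures rewritten throughout this commit."""
    return {"prompt_tokens": len(prompt.split()), "stop_count": len(stop or [])}


# Same idea for the _invoke_error_mapping-style return type seen in the hunks above.
ERROR_MAPPING: Dict[Type[Exception], List[Type[Exception]]] = {
    ValueError: [TypeError, KeyError],
}

if __name__ == "__main__":
    print(invoke("hello world", stop=["\n"]))
```

On Python 3.8 the commented-out form raises `TypeError` as soon as the function is defined, because the annotations are evaluated eagerly; the `typing` form works unchanged on 3.8 and later.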