diff --git a/chatchat-server/chatchat/model_loaders/init_server.py b/chatchat-server/chatchat/model_loaders/init_server.py
index ef909af6..78e50965 100644
--- a/chatchat-server/chatchat/model_loaders/init_server.py
+++ b/chatchat-server/chatchat/model_loaders/init_server.py
@@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict,
provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
model_platforms_shard['provider_platforms'] = provider_platforms
-
- boot.serve(logging_conf=logging_conf)
+ boot.logging_conf(logging_conf=logging_conf)
+ boot.run()
async def pool_join_thread():
await boot.join()
diff --git a/model-providers/README.md b/model-providers/README.md
index 97ba4abb..7dabe4d7 100644
--- a/model-providers/README.md
+++ b/model-providers/README.md
@@ -42,3 +42,39 @@ make format
make format_diff
```
当你对项目的一部分进行了更改,并希望确保更改的部分格式正确,而不影响代码库的其他部分时,这个命令特别有用。
+
+
+
+### Getting Started
+
+Once the project is installed, configure the `model_providers.yaml` file to load the platforms.
+> Note: before configuring a platform, make sure its dependencies are installed. For the Zhipu AI platform, for example, you need the Zhipu SDK: `pip install zhipuai`
+
+`model_providers` contains each platform's global configuration `provider_credential` and model configuration `model_credential`.
+The settings each platform loads differ. For how to configure this file,
+see the platform `yaml` files under the `model_providers.core.model_runtime.model_providers` package.
+For example, `zhipuai.yaml` defines a `provider_credential_schema` that contains a single variable, `api_key`.
+
+To load the Zhipu AI platform:
+
+- Install the SDK
+```shell
+$ pip install zhipuai
+```
+
+- Edit `model_providers.yaml`
+
+```yaml
+zhipuai:
+  provider_credential:
+    api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.2'
+```
+
+- Run the `model-providers` pytest suite
+```shell
+poetry run pytest tests/server_unit_test/test_init_server.py
+```
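+
+As a quick sanity check, here is a minimal Python sketch for querying the OpenAI-compatible endpoint once the server is running. The host, port, provider route prefix, and model name below are assumptions; adjust them to the host/port you configured for `run_openai_api`.
+
+```python
+from openai import OpenAI
+
+# Hypothetical address: replace host, port and provider prefix with your own settings.
+# Credentials are resolved server-side from model_providers.yaml, so api_key is a placeholder.
+client = OpenAI(base_url="http://127.0.0.1:20000/zhipuai/v1", api_key="EMPTY")
+
+resp = client.chat.completions.create(
+    model="glm-4",  # any chat model exposed by the zhipuai provider
+    messages=[{"role": "user", "content": "Hello"}],
+)
+print(resp.choices[0].message.content)
+```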
\ No newline at end of file
diff --git a/model-providers/model_providers.yaml b/model-providers/model_providers.yaml
index d88736b3..908883c7 100644
--- a/model-providers/model_providers.yaml
+++ b/model-providers/model_providers.yaml
@@ -27,3 +27,7 @@ xinference:
model_uid: 'chatglm3-6b'
+zhipuai:
+
+ provider_credential:
+ api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
\ No newline at end of file
diff --git a/model-providers/model_providers/__main__.py b/model-providers/model_providers/__main__.py
index ecde3e4a..fa4797fc 100644
--- a/model-providers/model_providers/__main__.py
+++ b/model-providers/model_providers/__main__.py
@@ -16,7 +16,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model-providers",
type=str,
- default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
+ default="/mnt/d/project/Langchain-Chatchat/model-providers/model_providers.yaml",
help="run model_providers servers",
dest="model_providers",
)
@@ -36,7 +36,8 @@ if __name__ == "__main__":
.build()
)
boot.set_app_event(started_event=None)
- boot.serve(logging_conf=logging_conf)
+ boot.logging_conf(logging_conf=logging_conf)
+ boot.run()
async def pool_join_thread():
await boot.join()
diff --git a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py
index e7899a0c..77819c8a 100644
--- a/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py
+++ b/model-providers/model_providers/bootstrap_web/entities/model_provider_entities.py
@@ -50,7 +50,7 @@ class SystemConfigurationResponse(BaseModel):
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
- quota_configurations: list[QuotaConfiguration] = []
+ quota_configurations: List[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
@@ -65,8 +65,8 @@ class ProviderResponse(BaseModel):
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
- supported_model_types: list[ModelType]
- configurate_methods: list[ConfigurateMethod]
+ supported_model_types: List[ModelType]
+ configurate_methods: List[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
@@ -114,7 +114,7 @@ class ProviderWithModelsResponse(BaseModel):
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
- models: list[ModelResponse]
+ models: List[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)
diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
index 3c3e91af..a0661ec6 100644
--- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
+++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py
@@ -18,6 +18,8 @@ from typing import (
cast,
)
+import tiktoken
+import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
@@ -67,8 +69,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._port = port
self._router = APIRouter()
self._app = FastAPI()
+ self._logging_conf = None
+ self._server = None
self._server_thread = None
+    def logging_conf(self, logging_conf: Optional[dict] = None):
+ self._logging_conf = logging_conf
+
@classmethod
def from_config(cls, cfg=None):
host = cfg.get("host", "127.0.0.1")
@@ -79,7 +86,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
return cls(host=host, port=port)
- def serve(self, logging_conf: Optional[dict] = None):
+ def run(self):
self._app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -125,18 +132,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
self._app.include_router(self._router)
config = Config(
- app=self._app, host=self._host, port=self._port, log_config=logging_conf
+ app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
)
- server = Server(config)
+ self._server = Server(config)
def run_server():
- server.run()
+            self._server.shutdown_timeout = 2  # set to 2 seconds
+
+ self._server.run()
self._server_thread = threading.Thread(target=run_server)
self._server_thread.start()
- async def join(self):
- await self._server_thread.join()
+    def destroy(self):
+        logger.info("Shutting down server")
+        # Setting should_exit asks the uvicorn server loop to stop; Server.shutdown()
+        # is a coroutine and would be a no-op when called synchronously here.
+        self._server.should_exit = True
+
+        self.join()
+
+ def join(self):
+ self._server_thread.join()
+
def set_app_event(self, started_event: mp.Event = None):
@self._app.on_event("startup")
@@ -159,7 +177,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
async def list_models(self, provider: str, request: Request):
logger.info(f"Received list_models request for provider: {provider}")
# 返回ModelType所有的枚举
- llm_models: list[AIModelEntity] = []
+ llm_models: List[AIModelEntity] = []
for model_type in ModelType.__members__.values():
try:
provider_model_bundle = (
@@ -176,7 +194,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
)
logger.error(e)
- # models list[AIModelEntity]转换称List[ModelCard]
+        # Convert models (List[AIModelEntity]) to List[ModelCard]
models_list = [
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
@@ -191,16 +209,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
logger.info(
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
)
- model_instance = self._provider_manager.get_model_instance(
- provider=provider,
- model_type=ModelType.TEXT_EMBEDDING,
- model=embeddings_request.model,
- )
- texts = embeddings_request.input
- if isinstance(texts, str):
- texts = [texts]
- response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
- return await openai_embedding_text(response)
+ try:
+ model_instance = self._provider_manager.get_model_instance(
+ provider=provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=embeddings_request.model,
+ )
+
+            # embeddings_request.input may be a plain string or, per the OpenAI
+            # embeddings API, a list whose items are strings or tiktoken token-id
+            # lists; decode token lists back to text before invoking the model.
+            input_text = ""
+            if isinstance(embeddings_request.input, list):
+                try:
+                    encoding = tiktoken.encoding_for_model(embeddings_request.model)
+                except KeyError:
+                    logger.warning("Model not found. Using cl100k_base encoding.")
+                    encoding = tiktoken.get_encoding("cl100k_base")
+                for item in embeddings_request.input:
+                    # Plain strings are appended as-is; token-id lists are decoded first.
+                    input_text += item if isinstance(item, str) else encoding.decode(item)
+            else:
+                input_text = embeddings_request.input
+
+            response = model_instance.invoke_text_embedding(
+                texts=[input_text], user="abc-123"
+            )
+ return await openai_embedding_text(response)
+
+ except ValueError as e:
+ logger.error(f"Error while creating embeddings: {str(e)}")
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+ except InvokeError as e:
+ logger.error(f"Error while creating embeddings: {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+ )
async def create_chat_completion(
self, provider: str, request: Request, chat_request: ChatCompletionRequest
@@ -254,9 +297,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
else:
return await openai_chat_completion(response)
except ValueError as e:
+ logger.error(f"Error while creating chat completion: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except InvokeError as e:
+ logger.error(f"Error while creating chat completion: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@@ -273,10 +318,11 @@ def run(
cfg=cfg.get("run_openai_api", {})
)
api.set_app_event(started_event=started_event)
- api.serve(logging_conf=logging_conf)
+ api.logging_conf(logging_conf=logging_conf)
+ api.run()
async def pool_join_thread():
- await api.join()
+ api.join()
asyncio.run(pool_join_thread())
except SystemExit:
diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py
index 2945c0ba..2bd364f3 100644
--- a/model-providers/model_providers/core/bootstrap/openai_protocol.py
+++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py
@@ -145,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = 256
- stop: Optional[list[str]] = None
+ stop: Optional[List[str]] = None
stream: Optional[bool] = False
def to_model_parameters_dict(self, *args, **kwargs):
diff --git a/model-providers/model_providers/core/bootstrap/providers_wapper.py b/model-providers/model_providers/core/bootstrap/providers_wapper.py
index d958a999..9d1858ed 100644
--- a/model-providers/model_providers/core/bootstrap/providers_wapper.py
+++ b/model-providers/model_providers/core/bootstrap/providers_wapper.py
@@ -112,7 +112,7 @@ class ProvidersWrapper:
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
- providers_with_models: list[ProviderWithModelsResponse] = []
+ providers_with_models: List[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
diff --git a/model-providers/model_providers/core/entities/application_entities.py b/model-providers/model_providers/core/entities/application_entities.py
index 9a5e0ff4..263693cc 100644
--- a/model-providers/model_providers/core/entities/application_entities.py
+++ b/model-providers/model_providers/core/entities/application_entities.py
@@ -21,9 +21,9 @@ class ModelConfigEntity(BaseModel):
model_schema: AIModelEntity
mode: str
provider_model_bundle: ProviderModelBundle
- credentials: dict[str, Any] = {}
- parameters: dict[str, Any] = {}
- stop: list[str] = []
+ credentials: Dict[str, Any] = {}
+ parameters: Dict[str, Any] = {}
+ stop: List[str] = []
class AdvancedChatMessageEntity(BaseModel):
@@ -40,7 +40,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
Advanced Chat Prompt Template Entity.
"""
- messages: list[AdvancedChatMessageEntity]
+ messages: List[AdvancedChatMessageEntity]
class AdvancedCompletionPromptTemplateEntity(BaseModel):
@@ -102,7 +102,7 @@ class ExternalDataVariableEntity(BaseModel):
variable: str
type: str
- config: dict[str, Any] = {}
+ config: Dict[str, Any] = {}
class DatasetRetrieveConfigEntity(BaseModel):
@@ -146,7 +146,7 @@ class DatasetEntity(BaseModel):
Dataset Config Entity.
"""
- dataset_ids: list[str]
+ dataset_ids: List[str]
retrieve_config: DatasetRetrieveConfigEntity
@@ -156,7 +156,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
type: str
- config: dict[str, Any] = {}
+ config: Dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
@@ -185,7 +185,7 @@ class AgentToolEntity(BaseModel):
provider_type: Literal["builtin", "api"]
provider_id: str
tool_name: str
- tool_parameters: dict[str, Any] = {}
+ tool_parameters: Dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
@@ -234,7 +234,7 @@ class AgentEntity(BaseModel):
model: str
strategy: Strategy
prompt: Optional[AgentPromptEntity] = None
- tools: list[AgentToolEntity] = None
+ tools: List[AgentToolEntity] = None
max_iteration: int = 5
@@ -245,7 +245,7 @@ class AppOrchestrationConfigEntity(BaseModel):
model_config: ModelConfigEntity
prompt_template: PromptTemplateEntity
- external_data_variables: list[ExternalDataVariableEntity] = []
+ external_data_variables: List[ExternalDataVariableEntity] = []
agent: Optional[AgentEntity] = None
# features
@@ -319,13 +319,13 @@ class ApplicationGenerateEntity(BaseModel):
app_orchestration_config_entity: AppOrchestrationConfigEntity
conversation_id: Optional[str] = None
- inputs: dict[str, str]
+ inputs: Dict[str, str]
query: Optional[str] = None
- files: list[FileObj] = []
+ files: List[FileObj] = []
user_id: str
# extras
stream: bool
invoke_from: InvokeFrom
# extra parameters, like: auto_generate_conversation_name
- extras: dict[str, Any] = {}
+ extras: Dict[str, Any] = {}
diff --git a/model-providers/model_providers/core/entities/message_entities.py b/model-providers/model_providers/core/entities/message_entities.py
index b7ad8172..52aa3fa0 100644
--- a/model-providers/model_providers/core/entities/message_entities.py
+++ b/model-providers/model_providers/core/entities/message_entities.py
@@ -47,12 +47,12 @@ class ImagePromptMessageFile(PromptMessageFile):
class LCHumanMessageWithFiles(HumanMessage):
- # content: Union[str, list[Union[str, Dict]]]
+ # content: Union[str,List[Union[str, Dict]]]
content: str
- files: list[PromptMessageFile]
+ files: List[PromptMessageFile]
-def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
+def lc_messages_to_prompt_messages(messages: List[BaseMessage]) -> List[PromptMessage]:
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
@@ -109,8 +109,8 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
def prompt_messages_to_lc_messages(
- prompt_messages: list[PromptMessage],
-) -> list[BaseMessage]:
+ prompt_messages: List[PromptMessage],
+) -> List[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
diff --git a/model-providers/model_providers/core/entities/model_entities.py b/model-providers/model_providers/core/entities/model_entities.py
index 20e5dbc9..cfaf6b82 100644
--- a/model-providers/model_providers/core/entities/model_entities.py
+++ b/model-providers/model_providers/core/entities/model_entities.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
@@ -31,7 +31,7 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
- supported_model_types: list[ModelType]
+ supported_model_types: List[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
"""
@@ -66,7 +66,7 @@ class DefaultModelProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
- supported_model_types: list[ModelType]
+ supported_model_types: List[ModelType]
class DefaultModelEntity(BaseModel):
diff --git a/model-providers/model_providers/core/entities/provider_configuration.py b/model-providers/model_providers/core/entities/provider_configuration.py
index 22d42587..a068bd9e 100644
--- a/model-providers/model_providers/core/entities/provider_configuration.py
+++ b/model-providers/model_providers/core/entities/provider_configuration.py
@@ -1,9 +1,8 @@
import datetime
import json
import logging
-from collections.abc import Iterator
from json import JSONDecodeError
-from typing import Optional
+from typing import Dict, Iterator, List, Optional
from pydantic import BaseModel
@@ -162,7 +161,7 @@ class ProviderConfiguration(BaseModel):
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False
- ) -> list[ModelWithProviderEntity]:
+ ) -> List[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
@@ -189,8 +188,8 @@ class ProviderConfiguration(BaseModel):
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_custom_provider_models(
- self, model_types: list[ModelType], provider_instance: ModelProvider
- ) -> list[ModelWithProviderEntity]:
+ self, model_types: List[ModelType], provider_instance: ModelProvider
+ ) -> List[ModelWithProviderEntity]:
"""
Get custom provider models.
@@ -266,7 +265,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
- configurations: dict[str, ProviderConfiguration] = {}
+ configurations: Dict[str, ProviderConfiguration] = {}
def __init__(self):
super().__init__()
@@ -276,7 +275,7 @@ class ProviderConfigurations(BaseModel):
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False,
- ) -> list[ModelWithProviderEntity]:
+ ) -> List[ModelWithProviderEntity]:
"""
Get available models.
@@ -317,7 +316,7 @@ class ProviderConfigurations(BaseModel):
return all_models
- def to_list(self) -> list[ProviderConfiguration]:
+ def to_list(self) -> List[ProviderConfiguration]:
"""
Convert to list.
diff --git a/model-providers/model_providers/core/entities/provider_entities.py b/model-providers/model_providers/core/entities/provider_entities.py
index ba4a3fb1..7b0705db 100644
--- a/model-providers/model_providers/core/entities/provider_entities.py
+++ b/model-providers/model_providers/core/entities/provider_entities.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
@@ -68,7 +68,7 @@ class QuotaConfiguration(BaseModel):
quota_limit: int
quota_used: int
is_valid: bool
- restrict_models: list[RestrictModel] = []
+ restrict_models: List[RestrictModel] = []
class SystemConfiguration(BaseModel):
@@ -78,7 +78,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
- quota_configurations: list[QuotaConfiguration] = []
+ quota_configurations: List[QuotaConfiguration] = []
credentials: Optional[dict] = None
@@ -106,4 +106,4 @@ class CustomConfiguration(BaseModel):
"""
provider: Optional[CustomProviderConfiguration] = None
- models: list[CustomModelConfiguration] = []
+ models: List[CustomModelConfiguration] = []
diff --git a/model-providers/model_providers/core/entities/queue_entities.py b/model-providers/model_providers/core/entities/queue_entities.py
index 7ba21aa6..f72080cb 100644
--- a/model-providers/model_providers/core/entities/queue_entities.py
+++ b/model-providers/model_providers/core/entities/queue_entities.py
@@ -68,7 +68,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event = QueueEvent.RETRIEVER_RESOURCES
- retriever_resources: list[dict]
+ retriever_resources: List[dict]
class AnnotationReplyEvent(AppQueueEvent):
diff --git a/model-providers/model_providers/core/model_manager.py b/model-providers/model_providers/core/model_manager.py
index af896423..fc045509 100644
--- a/model-providers/model_providers/core/model_manager.py
+++ b/model-providers/model_providers/core/model_manager.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import IO, Optional, Union, cast
+from typing import IO, Generator, List, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.model_runtime.callbacks.base_callback import Callback
@@ -68,13 +67,13 @@ class ModelInstance:
def invoke_llm(
self,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: Optional[dict] = None,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -105,7 +104,7 @@ class ModelInstance:
)
def invoke_text_embedding(
- self, texts: list[str], user: Optional[str] = None
+ self, texts: List[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke large language model
@@ -125,7 +124,7 @@ class ModelInstance:
def invoke_rerank(
self,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
diff --git a/model-providers/model_providers/core/model_runtime/README_CN.md b/model-providers/model_providers/core/model_runtime/README_CN.md
index 3664fa2c..de984853 100644
--- a/model-providers/model_providers/core/model_runtime/README_CN.md
+++ b/model-providers/model_providers/core/model_runtime/README_CN.md
@@ -69,9 +69,9 @@ Model Runtime 分三层:
在这里我们需要先区分模型参数与模型凭据。
- - 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: dict[str, any]**。
+ - 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: Dict[str, any]**。
- - 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
+ - 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: Dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
## 下一步
diff --git a/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py b/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py
index f7b8c3e5..c7a51cd1 100644
--- a/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py
+++ b/model-providers/model_providers/core/model_runtime/callbacks/base_callback.py
@@ -1,5 +1,5 @@
from abc import ABC
-from typing import Optional
+from typing import List, Optional
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -33,10 +33,10 @@ class Callback(ABC):
llm_instance: AIModel,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -61,10 +61,10 @@ class Callback(ABC):
chunk: LLMResultChunk,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@@ -90,10 +90,10 @@ class Callback(ABC):
result: LLMResult,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -119,10 +119,10 @@ class Callback(ABC):
ex: Exception,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
diff --git a/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py b/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py
index be78e354..287a9f67 100644
--- a/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py
+++ b/model-providers/model_providers/core/model_runtime/callbacks/logging_callback.py
@@ -1,7 +1,7 @@
import json
import logging
import sys
-from typing import Optional
+from typing import List, Optional
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import (
@@ -23,10 +23,10 @@ class LoggingCallback(Callback):
llm_instance: AIModel,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -79,10 +79,10 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
):
@@ -109,10 +109,10 @@ class LoggingCallback(Callback):
result: LLMResult,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
@@ -154,10 +154,10 @@ class LoggingCallback(Callback):
ex: Exception,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> None:
diff --git a/model-providers/model_providers/core/model_runtime/docs/en_US/interfaces.md b/model-providers/model_providers/core/model_runtime/docs/en_US/interfaces.md
index dc70bfad..0f3cfdd0 100644
--- a/model-providers/model_providers/core/model_runtime/docs/en_US/interfaces.md
+++ b/model-providers/model_providers/core/model_runtime/docs/en_US/interfaces.md
@@ -71,7 +71,7 @@ All models need to uniformly implement the following 2 methods:
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@@ -95,8 +95,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
```python
def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                prompt_messages: List[PromptMessage], model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@@ -157,8 +157,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@@ -196,7 +196,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
```python
def _invoke(self, model: str, credentials: dict,
- texts: list[str], user: Optional[str] = None) \
+                texts: List[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
@@ -230,7 +230,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
- Pre-calculating Tokens
```python
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -251,7 +251,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo
```python
def _invoke(self, model: str, credentials: dict,
- query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                query: str, docs: List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
@@ -498,7 +498,7 @@ class PromptMessage(ABC, BaseModel):
Model class for prompt message.
"""
role: PromptMessageRole
- content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
+    content: Optional[str | List[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
name: Optional[str] = None
```
@@ -539,7 +539,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction # tool call information
role: PromptMessageRole = PromptMessageRole.ASSISTANT
- tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
+    tool_calls: List[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
```
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
@@ -593,7 +593,7 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
model: str # Actual used modele
- prompt_messages: list[PromptMessage] # prompt messages
+    prompt_messages: List[PromptMessage] # prompt messages
message: AssistantPromptMessage # response message
usage: LLMUsage # usage info
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
@@ -624,7 +624,7 @@ class LLMResultChunk(BaseModel):
Model class for llm result chunk.
"""
model: str # Actual used modele
- prompt_messages: list[PromptMessage] # prompt messages
+    prompt_messages: List[PromptMessage] # prompt messages
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
delta: LLMResultChunkDelta
```
@@ -660,7 +660,7 @@ class TextEmbeddingResult(BaseModel):
Model class for text embedding result.
"""
model: str # Actual model used
- embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
+    embeddings: List[List[float]] # List of embedding vectors, corresponding to the input texts list
usage: EmbeddingUsage # Usage information
```
@@ -690,7 +690,7 @@ class RerankResult(BaseModel):
Model class for rerank result.
"""
model: str # Actual model used
- docs: list[RerankDocument] # Reranked document list
+    docs: List[RerankDocument] # Reranked document list
```
### RerankDocument
diff --git a/model-providers/model_providers/core/model_runtime/docs/en_US/provider_scale_out.md b/model-providers/model_providers/core/model_runtime/docs/en_US/provider_scale_out.md
index ba356c5c..9a1ba736 100644
--- a/model-providers/model_providers/core/model_runtime/docs/en_US/provider_scale_out.md
+++ b/model-providers/model_providers/core/model_runtime/docs/en_US/provider_scale_out.md
@@ -160,8 +160,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
```python
def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                prompt_messages: List[PromptMessage], model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@@ -184,8 +184,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
```python
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@@ -226,7 +226,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
index 7b3a8edb..8945d0f0 100644
--- a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
+++ b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
@@ -126,8 +126,8 @@ provider_credential_schema:
```python
def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                prompt_messages: List[PromptMessage], model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@@ -166,8 +166,8 @@ provider_credential_schema:
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@@ -283,7 +283,7 @@ provider_credential_schema:
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/interfaces.md b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/interfaces.md
index 743e575d..da921012 100644
--- a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/interfaces.md
+++ b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/interfaces.md
@@ -80,7 +80,7 @@ class XinferenceProvider(Provider):
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@@ -95,7 +95,7 @@ class XinferenceProvider(Provider):
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
@@ -127,8 +127,8 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                prompt_messages: List[PromptMessage], model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@@ -189,8 +189,8 @@ class XinferenceProvider(Provider):
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@@ -232,7 +232,7 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
- texts: list[str], user: Optional[str] = None) \
+                texts: List[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
@@ -266,7 +266,7 @@ class XinferenceProvider(Provider):
- 预计算 tokens
```python
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -289,7 +289,7 @@ class XinferenceProvider(Provider):
```python
def _invoke(self, model: str, credentials: dict,
- query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                query: str, docs: List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
@@ -538,7 +538,7 @@ class PromptMessage(ABC, BaseModel):
Model class for prompt message.
"""
role: PromptMessageRole # 消息角色
- content: Optional[str | list[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
+    content: Optional[str | List[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
name: Optional[str] = None # 名称,可选。
```
@@ -579,7 +579,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction # 工具调用信息
role: PromptMessageRole = PromptMessageRole.ASSISTANT
- tool_calls: list[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
+    tool_calls: List[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
```
其中 `tool_calls` 为调用模型传入 `tools` 后,由模型返回的 `tool call` 列表。
@@ -633,7 +633,7 @@ class LLMResult(BaseModel):
Model class for llm result.
"""
model: str # 实际使用模型
- prompt_messages: list[PromptMessage] # prompt 消息列表
+    prompt_messages: List[PromptMessage] # prompt 消息列表
message: AssistantPromptMessage # 回复消息
usage: LLMUsage # 使用的 tokens 及费用信息
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
@@ -664,7 +664,7 @@ class LLMResultChunk(BaseModel):
Model class for llm result chunk.
"""
model: str # 实际使用模型
- prompt_messages: list[PromptMessage] # prompt 消息列表
+    prompt_messages: List[PromptMessage] # prompt 消息列表
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
delta: LLMResultChunkDelta # 每个迭代存在变化的内容
```
@@ -700,7 +700,7 @@ class TextEmbeddingResult(BaseModel):
Model class for text embedding result.
"""
model: str # 实际使用模型
- embeddings: list[list[float]] # embedding 向量列表,对应传入的 texts 列表
+    embeddings: List[List[float]] # embedding 向量列表,对应传入的 texts 列表
usage: EmbeddingUsage # 使用信息
```
@@ -730,7 +730,7 @@ class RerankResult(BaseModel):
Model class for rerank result.
"""
model: str # 实际使用模型
- docs: list[RerankDocument] # 重排后的分段列表
+    docs: List[RerankDocument] # 重排后的分段列表
```
### RerankDocument
diff --git a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
index 56f379a9..ed569340 100644
--- a/model-providers/model_providers/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
+++ b/model-providers/model_providers/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
@@ -76,8 +76,8 @@ pricing: # 价格信息
```python
def _invoke(self, model: str, credentials: dict,
- prompt_messages: list[PromptMessage], model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                prompt_messages: List[PromptMessage], model_parameters: dict,
+ tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
@@ -116,8 +116,8 @@ pricing: # 价格信息
若模型未提供预计算 tokens 接口,可直接返回 0。
```python
- def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
@@ -158,7 +158,7 @@ pricing: # 价格信息
```python
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+    def _invoke_error_mapping(self) -> Dict[type[InvokeError], List[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/entities/defaults.py b/model-providers/model_providers/core/model_runtime/entities/defaults.py
index 98719aac..a9092830 100644
--- a/model-providers/model_providers/core/model_runtime/entities/defaults.py
+++ b/model-providers/model_providers/core/model_runtime/entities/defaults.py
@@ -1,8 +1,10 @@
+from typing import Dict
+
from model_providers.core.model_runtime.entities.model_entities import (
DefaultParameterName,
)
-PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
+PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
DefaultParameterName.TEMPERATURE: {
"label": {
"en_US": "Temperature",
diff --git a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py
index eafdfb2b..8976ff8c 100644
--- a/model-providers/model_providers/core/model_runtime/entities/llm_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/llm_entities.py
@@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
@@ -78,7 +78,7 @@ class LLMResult(BaseModel):
"""
model: str
- prompt_messages: list[PromptMessage]
+ prompt_messages: List[PromptMessage]
message: AssistantPromptMessage
usage: LLMUsage
system_fingerprint: Optional[str] = None
@@ -101,7 +101,7 @@ class LLMResultChunk(BaseModel):
"""
model: str
- prompt_messages: list[PromptMessage]
+ prompt_messages: List[PromptMessage]
system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta
diff --git a/model-providers/model_providers/core/model_runtime/entities/message_entities.py b/model-providers/model_providers/core/model_runtime/entities/message_entities.py
index a66294ad..0fd0ed17 100644
--- a/model-providers/model_providers/core/model_runtime/entities/message_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/message_entities.py
@@ -1,6 +1,6 @@
from abc import ABC
from enum import Enum
-from typing import Optional
+from typing import List, Optional, Union
from pydantic import BaseModel
@@ -110,7 +110,7 @@ class PromptMessage(ABC, BaseModel):
"""
role: PromptMessageRole
- content: Optional[str | list[PromptMessageContent]] = None
+ content: Optional[Union[str, List[PromptMessageContent]]] = None
name: Optional[str] = None
@@ -145,7 +145,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction
role: PromptMessageRole = PromptMessageRole.ASSISTANT
- tool_calls: list[ToolCall] = []
+ tool_calls: List[ToolCall] = []
class SystemPromptMessage(PromptMessage):
diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py
index 0fea8c1d..5cc6e80b 100644
--- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py
@@ -1,6 +1,6 @@
from decimal import Decimal
from enum import Enum
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
from pydantic import BaseModel
@@ -158,9 +158,9 @@ class ProviderModel(BaseModel):
model: str
label: I18nObject
model_type: ModelType
- features: Optional[list[ModelFeature]] = None
+ features: Optional[List[ModelFeature]] = None
fetch_from: FetchFrom
- model_properties: dict[ModelPropertyKey, Any]
+ model_properties: Dict[ModelPropertyKey, Any]
deprecated: bool = False
class Config:
@@ -182,7 +182,7 @@ class ParameterRule(BaseModel):
min: Optional[float] = None
max: Optional[float] = None
precision: Optional[int] = None
- options: list[str] = []
+ options: List[str] = []
class PriceConfig(BaseModel):
@@ -201,7 +201,7 @@ class AIModelEntity(ProviderModel):
Model class for AI model.
"""
- parameter_rules: list[ParameterRule] = []
+ parameter_rules: List[ParameterRule] = []
pricing: Optional[PriceConfig] = None
diff --git a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py
index 21b610ad..2bfee500 100644
--- a/model-providers/model_providers/core/model_runtime/entities/provider_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/provider_entities.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Optional
+from typing import List, Optional
from pydantic import BaseModel
@@ -48,7 +48,7 @@ class FormOption(BaseModel):
label: I18nObject
value: str
- show_on: list[FormShowOnObject] = []
+ show_on: List[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@@ -66,10 +66,10 @@ class CredentialFormSchema(BaseModel):
type: FormType
required: bool = True
default: Optional[str] = None
- options: Optional[list[FormOption]] = None
+ options: Optional[List[FormOption]] = None
placeholder: Optional[I18nObject] = None
max_length: int = 0
- show_on: list[FormShowOnObject] = []
+ show_on: List[FormShowOnObject] = []
class ProviderCredentialSchema(BaseModel):
@@ -77,7 +77,7 @@ class ProviderCredentialSchema(BaseModel):
Model class for provider credential schema.
"""
- credential_form_schemas: list[CredentialFormSchema]
+ credential_form_schemas: List[CredentialFormSchema]
class FieldModelSchema(BaseModel):
@@ -91,7 +91,7 @@ class ModelCredentialSchema(BaseModel):
"""
model: FieldModelSchema
- credential_form_schemas: list[CredentialFormSchema]
+ credential_form_schemas: List[CredentialFormSchema]
class SimpleProviderEntity(BaseModel):
@@ -103,8 +103,8 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
- supported_model_types: list[ModelType]
- models: list[AIModelEntity] = []
+ supported_model_types: List[ModelType]
+ models: List[AIModelEntity] = []
class ProviderHelpEntity(BaseModel):
@@ -128,9 +128,9 @@ class ProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
- supported_model_types: list[ModelType]
- configurate_methods: list[ConfigurateMethod]
- models: list[ProviderModel] = []
+ supported_model_types: List[ModelType]
+ configurate_methods: List[ConfigurateMethod]
+ models: List[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
diff --git a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py
index 99709e1b..034a7286 100644
--- a/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/rerank_entities.py
@@ -1,3 +1,5 @@
+from typing import List
+
from pydantic import BaseModel
@@ -17,4 +19,4 @@ class RerankResult(BaseModel):
"""
model: str
- docs: list[RerankDocument]
+ docs: List[RerankDocument]
diff --git a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py
index fa2172a0..454e41ee 100644
--- a/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py
+++ b/model-providers/model_providers/core/model_runtime/entities/text_embedding_entities.py
@@ -1,4 +1,5 @@
from decimal import Decimal
+from typing import List
from pydantic import BaseModel
@@ -25,5 +26,5 @@ class TextEmbeddingResult(BaseModel):
"""
model: str
- embeddings: list[list[float]]
+ embeddings: List[List[float]]
usage: EmbeddingUsage
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py
index 2c3233d4..6d34f88a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/ai_model.py
@@ -1,7 +1,7 @@
import decimal
import os
from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Dict, List, Optional, Type
import yaml
@@ -35,7 +35,7 @@ class AIModel(ABC):
"""
model_type: ModelType
- model_schemas: list[AIModelEntity] = None
+ model_schemas: List[AIModelEntity] = None
started_at: float = 0
@abstractmethod
@@ -51,7 +51,7 @@ class AIModel(ABC):
@property
@abstractmethod
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@@ -133,7 +133,7 @@ class AIModel(ABC):
currency=price_config.currency,
)
- def predefined_models(self) -> list[AIModelEntity]:
+ def predefined_models(self) -> List[AIModelEntity]:
"""
Get all predefined models for given provider.
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py
index 2b0a1e20..2b09bfdb 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/large_language_model.py
@@ -3,8 +3,7 @@ import os
import re
import time
from abc import abstractmethod
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.callbacks.logging_callback import (
@@ -47,13 +46,13 @@ class LargeLanguageModel(AIModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: Optional[dict] = None,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -170,13 +169,13 @@ class LargeLanguageModel(AIModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
@@ -290,7 +289,7 @@ if you are not sure about the structure.
def _code_block_mode_stream_processor(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None],
) -> Generator[LLMResultChunk, None, None]:
"""
@@ -428,13 +427,13 @@ if you are not sure about the structure.
model: str,
result: Generator,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Generator:
"""
Invoke result generator
@@ -498,10 +497,10 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -525,8 +524,8 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -539,7 +538,7 @@ if you are not sure about the structure.
"""
raise NotImplementedError
- def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
+ def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text, maxsplit=1)[0]
@@ -575,7 +574,7 @@ if you are not sure about the structure.
index += 1
time.sleep(0.01)
- def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
+ def get_parameter_rules(self, model: str, credentials: dict) -> List[ParameterRule]:
"""
Get parameter rules
@@ -658,13 +657,13 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> None:
"""
Trigger before invoke callbacks
@@ -706,13 +705,13 @@ if you are not sure about the structure.
chunk: LLMResultChunk,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> None:
"""
Trigger new chunk callbacks
@@ -755,13 +754,13 @@ if you are not sure about the structure.
model: str,
result: LLMResult,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> None:
"""
Trigger after invoke callbacks
@@ -805,13 +804,13 @@ if you are not sure about the structure.
model: str,
ex: Exception,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> None:
"""
Trigger invoke error callbacks
@@ -911,7 +910,7 @@ if you are not sure about the structure.
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
)
elif parameter_rule.type == ParameterType.FLOAT:
- if not isinstance(parameter_value, float | int):
+ if not isinstance(parameter_value, (float, int)):
raise ValueError(
f"Model Parameter {parameter_name} should be float."
)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py
index 48f4a942..40fea584 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/model_provider.py
@@ -1,6 +1,7 @@
import importlib
import os
from abc import ABC, abstractmethod
+from typing import Dict, List
import yaml
@@ -14,7 +15,7 @@ from model_providers.core.model_runtime.model_providers.__base.ai_model import A
class ModelProvider(ABC):
provider_schema: ProviderEntity = None
- model_instance_map: dict[str, AIModel] = {}
+ model_instance_map: Dict[str, AIModel] = {}
@abstractmethod
def validate_provider_credentials(self, credentials: dict) -> None:
@@ -65,7 +66,7 @@ class ModelProvider(ABC):
return provider_schema
- def models(self, model_type: ModelType) -> list[AIModelEntity]:
+ def models(self, model_type: ModelType) -> List[AIModelEntity]:
"""
Get all models for given model type
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py
index 6fdafde9..c37cc7d7 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/rerank_model.py
@@ -1,6 +1,6 @@
import time
from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
@@ -19,7 +19,7 @@ class RerankModel(AIModel):
model: str,
credentials: dict,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@@ -51,7 +51,7 @@ class RerankModel(AIModel):
model: str,
credentials: dict,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py
index 058c910e..4eab10fc 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/text2img_model.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import IO, Optional
+from typing import IO, List, Optional
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -19,7 +19,7 @@ class Text2ImageModel(AIModel):
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
- ) -> list[IO[bytes]]:
+ ) -> List[IO[bytes]]:
"""
Invoke Text2Image model
@@ -44,7 +44,7 @@ class Text2ImageModel(AIModel):
prompt: str,
model_parameters: dict,
user: Optional[str] = None,
- ) -> list[IO[bytes]]:
+ ) -> List[IO[bytes]]:
"""
Invoke Text2Image model
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py b/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py
index d46b412e..101d2460 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -1,6 +1,6 @@
import time
from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional
from model_providers.core.model_runtime.entities.model_entities import (
ModelPropertyKey,
@@ -23,7 +23,7 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -62,7 +62,7 @@ class TextEmbeddingModel(AIModel):
raise NotImplementedError
@abstractmethod
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py
index 032c2757..df3308f5 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -1,7 +1,6 @@
import base64
import mimetypes
-from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Dict, Generator, List, Optional, Tuple, Type, Union, cast
import anthropic
import requests
@@ -63,10 +62,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -92,9 +91,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -158,13 +157,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@@ -203,12 +202,12 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
+ user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@@ -251,8 +250,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -297,7 +296,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Message,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm chat response
@@ -345,7 +344,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[MessageStreamEvent],
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
@@ -424,8 +423,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs
def _convert_prompt_messages(
- self, prompt_messages: list[PromptMessage]
- ) -> tuple[str, list[dict]]:
+ self, prompt_messages: List[PromptMessage]
+    ) -> Tuple[str, List[dict]]:
"""
Convert prompt messages to dict list and system
"""
@@ -560,7 +559,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt_anthropic(
- self, messages: list[PromptMessage]
+ self, messages: List[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@@ -583,7 +582,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py
index 6ae57a15..b09f51a2 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
import openai
from httpx import Timeout
@@ -29,7 +31,7 @@ class _CommonAzureOpenAI:
return credentials_kwargs
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py
index c712fd06..5da71d39 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -1,7 +1,6 @@
import copy
import logging
-from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Generator, List, Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
@@ -60,10 +59,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -102,8 +101,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
model_mode = self._get_ai_model_entity(
credentials.get("base_model_name"), model
@@ -176,9 +175,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -215,7 +214,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Completion,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
assistant_text = response.choices[0].text
@@ -257,7 +256,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[Completion],
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
full_text = ""
for chunk in response:
@@ -321,10 +320,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -381,8 +380,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: ChatCompletion,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> LLMResult:
assistant_message = response.choices[0].message
# assistant_message_tool_calls = assistant_message.tool_calls
@@ -435,8 +434,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> Generator:
index = 0
full_assistant_content = ""
@@ -545,8 +544,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _extract_response_tool_calls(
- response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
- ) -> list[AssistantPromptMessage.ToolCall]:
+ response_tool_calls: List[
+ Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
+ ],
+ ) -> List[AssistantPromptMessage.ToolCall]:
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
@@ -566,7 +567,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _extract_response_function_call(
- response_function_call: FunctionCall | ChoiceDeltaFunctionCall,
+ response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall],
) -> AssistantPromptMessage.ToolCall:
tool_call = None
if response_function_call:
@@ -651,7 +652,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
self,
credentials: dict,
text: str,
- tools: Optional[list[PromptMessageTool]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
@@ -668,8 +669,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _num_tokens_from_messages(
self,
credentials: dict,
- messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@@ -743,7 +744,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
@staticmethod
def _num_tokens_for_tools(
- encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
+ encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
) -> int:
num_tokens = 0
for tool in tools:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py
index 17e442fe..22a2ba5a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py
@@ -1,7 +1,7 @@
import base64
import copy
import time
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import tiktoken
@@ -35,7 +35,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
base_model_name = credentials["base_model_name"]
@@ -51,7 +51,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
- embeddings: list[list[float]] = [[] for _ in range(len(texts))]
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@@ -81,8 +81,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
- results: list[list[list[float]]] = [[] for _ in range(len(texts))]
- num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
+        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@@ -112,7 +112,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
embeddings=embeddings, usage=usage, model=base_model_name
)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
if len(texts) == 0:
return 0
@@ -168,9 +168,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
def _embedding_invoke(
model: str,
client: AzureOpenAI,
- texts: Union[list[str], str],
+ texts: Union[List[str], str],
extra_model_kwargs: dict,
- ) -> tuple[list[list[float]], int]:
+    ) -> Tuple[List[List[float]], int]:
response = client.embeddings.create(
input=texts,
model=model,
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
index 9fcffccc..94d7a392 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
@@ -1,8 +1,7 @@
-from collections.abc import Generator
from enum import Enum
from hashlib import md5
from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Dict, Generator, List, Union
from requests import post
@@ -25,10 +24,10 @@ class BaichuanMessage:
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Dict[str, int] = None
stop_reason: str = ""
- def to_dict(self) -> dict[str, Any]:
+ def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@@ -112,9 +111,9 @@ class BaichuanModel:
self,
model: str,
stream: bool,
- messages: list[BaichuanMessage],
- parameters: dict[str, Any],
- ) -> dict[str, Any]:
+ messages: List[BaichuanMessage],
+ parameters: Dict[str, Any],
+ ) -> Dict[str, Any]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
@@ -165,7 +164,7 @@ class BaichuanModel:
else:
raise BadRequestError(f"Unknown model: {model}")
- def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
+ def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
if (
model == "baichuan2-turbo"
or model == "baichuan2-turbo-192k"
@@ -187,8 +186,8 @@ class BaichuanModel:
self,
model: str,
stream: bool,
- messages: list[BaichuanMessage],
- parameters: dict[str, Any],
+ messages: List[BaichuanMessage],
+ parameters: Dict[str, Any],
timeout: int,
) -> Union[Generator, BaichuanMessage]:
if (
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py
index b1691b5f..d2c116c0 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import cast
+from typing import Dict, Generator, List, Type, Union, cast
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -49,13 +48,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@@ -71,14 +70,14 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
- messages: list[PromptMessage],
+ messages: List[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
@@ -149,13 +148,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
if tools is not None and len(tools) > 0:
raise InvokeBadRequestError("Baichuan model doesn't support tools")
@@ -195,7 +194,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: BaichuanMessage,
) -> LLMResult:
@@ -216,7 +215,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[BaichuanMessage, None, None],
) -> Generator:
@@ -258,7 +257,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index 0f008858..e2a6d0f9 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -47,7 +47,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -93,8 +93,8 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
return result
def embedding(
- self, model: str, api_key, texts: list[str], user: Optional[str] = None
- ) -> tuple[list[list[float]], int]:
+ self, model: str, api_key, texts: List[str], user: Optional[str] = None
+    ) -> Tuple[List[List[float]], int]:
"""
Embed given texts
@@ -154,7 +154,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
return [data["embedding"] for data in embeddings], usage["total_tokens"]
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -183,7 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key")
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [InternalServerError],
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py
index 48a9e990..5741efcd 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -1,7 +1,6 @@
import json
import logging
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Dict, Generator, List, Optional, Type, Union
import boto3
from botocore.config import Config
@@ -48,10 +47,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -77,8 +76,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- messages: list[PromptMessage] | str,
- tools: Optional[list[PromptMessageTool]] = None,
+ messages: Union[List[PromptMessage], str],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -99,7 +98,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(
- self, model_prefix: str, messages: list[PromptMessage]
+ self, model_prefix: str, messages: List[PromptMessage]
) -> str:
"""
Format a list of messages into a full prompt for the Google model
@@ -190,7 +189,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_text
def _convert_messages_to_prompt(
- self, messages: list[PromptMessage], model_prefix: str
+ self, messages: List[PromptMessage], model_prefix: str
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
@@ -216,9 +215,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _create_payload(
self,
model_prefix: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
):
"""
@@ -282,9 +281,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -356,7 +355,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -436,7 +435,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -551,7 +550,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
        The key is the error type thrown to the caller
@@ -570,7 +569,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
def _map_client_to_invoke_error(
self, error_code: str, error_msg: str
- ) -> type[InvokeError]:
+ ) -> Type[InvokeError]:
"""
Map client error to invoke error
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_l_en.svg b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_l_en.svg
deleted file mode 100644
index a824d43d..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_l_en.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_s_en.svg b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_s_en.svg
deleted file mode 100644
index 466b4fce..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/_assets/icon_s_en.svg
+++ /dev/null
@@ -1,9 +0,0 @@
-
\ No newline at end of file
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py
deleted file mode 100644
index f0bd8825..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import logging
-
-from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.errors.validate import (
- CredentialsValidateFailedError,
-)
-from model_providers.core.model_runtime.model_providers.__base.model_provider import (
- ModelProvider,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class ChatGLMProvider(ModelProvider):
- def validate_provider_credentials(self, credentials: dict) -> None:
- """
- Validate provider credentials
-
- if validate failed, raise exception
-
- :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
- """
- try:
- model_instance = self.get_model_instance(ModelType.LLM)
-
- # Use `chatglm3-6b` model for validate,
- model_instance.validate_credentials(
- model="chatglm3-6b", credentials=credentials
- )
- except CredentialsValidateFailedError as ex:
- raise ex
- except Exception as ex:
- logger.exception(
- f"{self.get_provider_schema().provider} credentials validate failed"
- )
- raise ex
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.yaml b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.yaml
deleted file mode 100644
index 0c1688c3..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/chatglm.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-provider: chatglm
-label:
- en_US: ChatGLM
-icon_small:
- en_US: icon_s_en.svg
-icon_large:
- en_US: icon_l_en.svg
-background: "#F4F7FF"
-help:
- title:
- en_US: Deploy ChatGLM to your local
- zh_Hans: 部署您的本地 ChatGLM
- url:
- en_US: https://github.com/THUDM/ChatGLM3
-supported_model_types:
- - llm
-configurate_methods:
- - predefined-model
-provider_credential_schema:
- credential_form_schemas:
- - variable: api_base
- label:
- en_US: API URL
- type: text-input
- required: true
- placeholder:
- zh_Hans: 在此输入您的 API URL
- en_US: Enter your API URL
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b-32k.yaml b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b-32k.yaml
deleted file mode 100644
index d1075d74..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b-32k.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-model: chatglm2-6b-32k
-label:
- en_US: ChatGLM2-6B-32K
-model_type: llm
-features:
- - agent-thought
-model_properties:
- mode: chat
- context_size: 32000
-parameter_rules:
- - name: temperature
- use_template: temperature
- - name: top_p
- use_template: top_p
- required: false
- - name: max_tokens
- use_template: max_tokens
- required: true
- default: 2000
- min: 1
- max: 32000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b.yaml b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b.yaml
deleted file mode 100644
index e3cfeb90..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm2-6b.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-model: chatglm2-6b
-label:
- en_US: ChatGLM2-6B
-model_type: llm
-features:
- - agent-thought
-model_properties:
- mode: chat
- context_size: 2000
-parameter_rules:
- - name: temperature
- use_template: temperature
- - name: top_p
- use_template: top_p
- required: false
- - name: max_tokens
- use_template: max_tokens
- required: true
- default: 256
- min: 1
- max: 2000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b-32k.yaml b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b-32k.yaml
deleted file mode 100644
index 6f347435..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b-32k.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-model: chatglm3-6b-32k
-label:
- en_US: ChatGLM3-6B-32K
-model_type: llm
-features:
- - tool-call
- - agent-thought
-model_properties:
- mode: chat
- context_size: 32000
-parameter_rules:
- - name: temperature
- use_template: temperature
- - name: top_p
- use_template: top_p
- required: false
- - name: max_tokens
- use_template: max_tokens
- required: true
- default: 8000
- min: 1
- max: 32000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b.yaml b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b.yaml
deleted file mode 100644
index d6d87e2e..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/chatglm3-6b.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-model: chatglm3-6b
-label:
- en_US: ChatGLM3-6B
-model_type: llm
-features:
- - tool-call
- - agent-thought
-model_properties:
- mode: chat
- context_size: 8000
-parameter_rules:
- - name: temperature
- use_template: temperature
- - name: top_p
- use_template: top_p
- required: false
- - name: max_tokens
- use_template: max_tokens
- required: true
- default: 256
- min: 1
- max: 8000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py
deleted file mode 100644
index 1f798a26..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/chatglm/llm/llm.py
+++ /dev/null
@@ -1,555 +0,0 @@
-import logging
-from collections.abc import Generator
-from os.path import join
-from typing import Optional, cast
-
-from httpx import Timeout
-from openai import (
- APIConnectionError,
- APITimeoutError,
- AuthenticationError,
- ConflictError,
- InternalServerError,
- NotFoundError,
- OpenAI,
- PermissionDeniedError,
- RateLimitError,
- Stream,
- UnprocessableEntityError,
-)
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
-from openai.types.chat.chat_completion_message import FunctionCall
-
-from model_providers.core.model_runtime.entities.llm_entities import (
- LLMResult,
- LLMResultChunk,
- LLMResultChunkDelta,
-)
-from model_providers.core.model_runtime.entities.message_entities import (
- AssistantPromptMessage,
- PromptMessage,
- PromptMessageTool,
- SystemPromptMessage,
- ToolPromptMessage,
- UserPromptMessage,
-)
-from model_providers.core.model_runtime.errors.invoke import (
- InvokeAuthorizationError,
- InvokeBadRequestError,
- InvokeConnectionError,
- InvokeError,
- InvokeRateLimitError,
- InvokeServerUnavailableError,
-)
-from model_providers.core.model_runtime.errors.validate import (
- CredentialsValidateFailedError,
-)
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
- LargeLanguageModel,
-)
-from model_providers.core.model_runtime.utils import helper
-
-logger = logging.getLogger(__name__)
-
-
-class ChatGLMLargeLanguageModel(LargeLanguageModel):
- def _invoke(
- self,
- model: str,
- credentials: dict,
- prompt_messages: list[PromptMessage],
- model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
- stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
- """
- Invoke large language model
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param model_parameters: model parameters
- :param tools: tools for tool calling
- :param stop: stop words
- :param stream: is stream response
- :param user: unique user id
- :return: full response or stream response chunk generator result
- """
- # invoke model
- return self._generate(
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
-
- def get_num_tokens(
- self,
- model: str,
- credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
- ) -> int:
- """
- Get number of tokens for given prompt messages
-
- :param model: model name
- :param credentials: model credentials
- :param prompt_messages: prompt messages
- :param tools: tools for tool calling
- :return:
- """
- return self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-
- def validate_credentials(self, model: str, credentials: dict) -> None:
- """
- Validate model credentials
-
- :param model: model name
- :param credentials: model credentials
- :return:
- """
- try:
- self._invoke(
- model=model,
- credentials=credentials,
- prompt_messages=[
- UserPromptMessage(content="ping"),
- ],
- model_parameters={
- "max_tokens": 16,
- },
- )
- except Exception as e:
- raise CredentialsValidateFailedError(str(e))
-
- @property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
- """
- Map model invoke error to unified error
- The key is the error type thrown to the caller
- The value is the error type thrown by the model,
- which needs to be converted into a unified error type for the caller.
-
- :return: Invoke error mapping
- """
- return {
- InvokeConnectionError: [
- APIConnectionError,
- APITimeoutError,
- ],
- InvokeServerUnavailableError: [
- InternalServerError,
- ConflictError,
- NotFoundError,
- UnprocessableEntityError,
- PermissionDeniedError,
- ],
- InvokeRateLimitError: [RateLimitError],
- InvokeAuthorizationError: [AuthenticationError],
- InvokeBadRequestError: [ValueError],
- }
-
- def _generate(
- self,
- model: str,
- credentials: dict,
- prompt_messages: list[PromptMessage],
- model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
- stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
- """
- Invoke large language model
-
- :param model: model name
- :param credentials: credentials kwargs
- :param prompt_messages: prompt messages
- :param model_parameters: model parameters
- :param stop: stop words
- :param stream: is stream response
- :param user: unique user id
- :return: full response or stream response chunk generator result
- """
-
- self._check_chatglm_parameters(
- model=model, model_parameters=model_parameters, tools=tools
- )
-
- kwargs = self._to_client_kwargs(credentials)
- # init model client
- client = OpenAI(**kwargs)
-
- extra_model_kwargs = {}
- if stop:
- extra_model_kwargs["stop"] = stop
-
- if user:
- extra_model_kwargs["user"] = user
-
- if tools and len(tools) > 0:
- extra_model_kwargs["functions"] = [
- helper.dump_model(tool) for tool in tools
- ]
-
- result = client.chat.completions.create(
- messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
- model=model,
- stream=stream,
- **model_parameters,
- **extra_model_kwargs,
- )
-
- if stream:
- return self._handle_chat_generate_stream_response(
- model=model,
- credentials=credentials,
- response=result,
- tools=tools,
- prompt_messages=prompt_messages,
- )
-
- return self._handle_chat_generate_response(
- model=model,
- credentials=credentials,
- response=result,
- tools=tools,
- prompt_messages=prompt_messages,
- )
-
- def _check_chatglm_parameters(
- self, model: str, model_parameters: dict, tools: list[PromptMessageTool]
- ) -> None:
- if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
- raise InvokeBadRequestError("ChatGLM2 does not support function calling")
-
- def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
- """
- Convert PromptMessage to dict for OpenAI Compatibility API
- """
- if isinstance(message, UserPromptMessage):
- message = cast(UserPromptMessage, message)
- if isinstance(message.content, str):
- message_dict = {"role": "user", "content": message.content}
- else:
- raise ValueError("User message content must be str")
- elif isinstance(message, AssistantPromptMessage):
- message = cast(AssistantPromptMessage, message)
- message_dict = {"role": "assistant", "content": message.content}
- if message.tool_calls and len(message.tool_calls) > 0:
- message_dict["function_call"] = {
- "name": message.tool_calls[0].function.name,
- "arguments": message.tool_calls[0].function.arguments,
- }
- elif isinstance(message, SystemPromptMessage):
- message = cast(SystemPromptMessage, message)
- message_dict = {"role": "system", "content": message.content}
- elif isinstance(message, ToolPromptMessage):
- # check if last message is user message
- message = cast(ToolPromptMessage, message)
- message_dict = {"role": "function", "content": message.content}
- else:
- raise ValueError(f"Unknown message type {type(message)}")
-
- return message_dict
-
- def _extract_response_tool_calls(
- self, response_function_calls: list[FunctionCall]
- ) -> list[AssistantPromptMessage.ToolCall]:
- """
- Extract tool calls from response
-
- :param response_tool_calls: response tool calls
- :return: list of tool calls
- """
- tool_calls = []
- if response_function_calls:
- for response_tool_call in response_function_calls:
- function = AssistantPromptMessage.ToolCall.ToolCallFunction(
- name=response_tool_call.name, arguments=response_tool_call.arguments
- )
-
- tool_call = AssistantPromptMessage.ToolCall(
- id=0, type="function", function=function
- )
- tool_calls.append(tool_call)
-
- return tool_calls
-
- def _to_client_kwargs(self, credentials: dict) -> dict:
- """
- Convert invoke kwargs to client kwargs
-
- :param stream: is stream response
- :param model_name: model name
- :param credentials: credentials dict
- :param model_parameters: model parameters
- :return: client kwargs
- """
- client_kwargs = {
- "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
- "api_key": "1",
- "base_url": join(credentials["api_base"], "v1"),
- }
-
- return client_kwargs
-
- def _handle_chat_generate_stream_response(
- self,
- model: str,
- credentials: dict,
- response: Stream[ChatCompletionChunk],
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
- ) -> Generator:
- full_response = ""
-
- for chunk in response:
- if len(chunk.choices) == 0:
- continue
-
- delta = chunk.choices[0]
-
- if delta.finish_reason is None and (
- delta.delta.content is None or delta.delta.content == ""
- ):
- continue
-
- # check if there is a tool call in the response
- function_calls = None
- if delta.delta.function_call:
- function_calls = [delta.delta.function_call]
-
- assistant_message_tool_calls = self._extract_response_tool_calls(
- function_calls if function_calls else []
- )
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else "",
- tool_calls=assistant_message_tool_calls,
- )
-
- if delta.finish_reason is not None:
- # temp_assistant_prompt_message is used to calculate usage
- temp_assistant_prompt_message = AssistantPromptMessage(
- content=full_response, tool_calls=assistant_message_tool_calls
- )
-
- prompt_tokens = self._num_tokens_from_messages(
- messages=prompt_messages, tools=tools
- )
- completion_tokens = self._num_tokens_from_messages(
- messages=[temp_assistant_prompt_message], tools=[]
- )
-
- usage = self._calc_response_usage(
- model=model,
- credentials=credentials,
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- )
-
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=delta.index,
- message=assistant_prompt_message,
- finish_reason=delta.finish_reason,
- usage=usage,
- ),
- )
- else:
- yield LLMResultChunk(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=chunk.system_fingerprint,
- delta=LLMResultChunkDelta(
- index=delta.index,
- message=assistant_prompt_message,
- ),
- )
-
- full_response += delta.delta.content
-
- def _handle_chat_generate_response(
- self,
- model: str,
- credentials: dict,
- response: ChatCompletion,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
- ) -> LLMResult:
- """
- Handle llm chat response
-
- :param model: model name
- :param credentials: credentials
- :param response: response
- :param prompt_messages: prompt messages
- :param tools: tools for tool calling
- :return: llm response
- """
- if len(response.choices) == 0:
- raise InvokeServerUnavailableError("Empty response")
- assistant_message = response.choices[0].message
-
- # convert function call to tool call
- function_calls = assistant_message.function_call
- tool_calls = self._extract_response_tool_calls(
- [function_calls] if function_calls else []
- )
-
- # transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
- content=assistant_message.content, tool_calls=tool_calls
- )
-
- prompt_tokens = self._num_tokens_from_messages(
- messages=prompt_messages, tools=tools
- )
- completion_tokens = self._num_tokens_from_messages(
- messages=[assistant_prompt_message], tools=tools
- )
-
- usage = self._calc_response_usage(
- model=model,
- credentials=credentials,
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- )
-
- response = LLMResult(
- model=model,
- prompt_messages=prompt_messages,
- system_fingerprint=response.system_fingerprint,
- usage=usage,
- message=assistant_prompt_message,
- )
-
- return response
-
- def _num_tokens_from_string(
- self, text: str, tools: Optional[list[PromptMessageTool]] = None
- ) -> int:
- """
- Calculate num tokens for text completion model with tiktoken package.
-
- :param model: model name
- :param text: prompt text
- :param tools: tools for tool calling
- :return: number of tokens
- """
- num_tokens = self._get_num_tokens_by_gpt2(text)
-
- if tools:
- num_tokens += self._num_tokens_for_tools(tools)
-
- return num_tokens
-
- def _num_tokens_from_messages(
- self,
- messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
- ) -> int:
- """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
-
- it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
- As a temporary solution we use GPT2 tokenizer instead.
-
- """
-
- def tokens(text: str):
- return self._get_num_tokens_by_gpt2(text)
-
- tokens_per_message = 3
- tokens_per_name = 1
- num_tokens = 0
- messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
- for message in messages_dict:
- num_tokens += tokens_per_message
- for key, value in message.items():
- if isinstance(value, list):
- text = ""
- for item in value:
- if isinstance(item, dict) and item["type"] == "text":
- text += item["text"]
- value = text
-
- if key == "function_call":
- for t_key, t_value in value.items():
- num_tokens += tokens(t_key)
- if t_key == "function":
- for f_key, f_value in t_value.items():
- num_tokens += tokens(f_key)
- num_tokens += tokens(f_value)
- else:
- num_tokens += tokens(t_key)
- num_tokens += tokens(t_value)
- else:
- num_tokens += tokens(str(value))
-
- if key == "name":
- num_tokens += tokens_per_name
-
- # every reply is primed with assistant
- num_tokens += 3
-
- if tools:
- num_tokens += self._num_tokens_for_tools(tools)
-
- return num_tokens
-
- def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
- """
- Calculate num tokens for tool calling
-
- :param encoding: encoding
- :param tools: tools for tool calling
- :return: number of tokens
- """
-
- def tokens(text: str):
- return self._get_num_tokens_by_gpt2(text)
-
- num_tokens = 0
- for tool in tools:
- # calculate num tokens for function object
- num_tokens += tokens("name")
- num_tokens += tokens(tool.name)
- num_tokens += tokens("description")
- num_tokens += tokens(tool.description)
- parameters = tool.parameters
- num_tokens += tokens("parameters")
- num_tokens += tokens("type")
- num_tokens += tokens(parameters.get("type"))
- if "properties" in parameters:
- num_tokens += tokens("properties")
- for key, value in parameters.get("properties").items():
- num_tokens += tokens(key)
- for field_key, field_value in value.items():
- num_tokens += tokens(field_key)
- if field_key == "enum":
- for enum_field in field_value:
- num_tokens += 3
- num_tokens += tokens(enum_field)
- else:
- num_tokens += tokens(field_key)
- num_tokens += tokens(str(field_value))
- if "required" in parameters:
- num_tokens += tokens("required")
- for required_field in parameters["required"]:
- num_tokens += 3
- num_tokens += tokens(required_field)
-
- return num_tokens
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py
index 620a5b91..fde1a2a1 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Dict, List, Optional, Tuple, Type, Union, cast
import cohere
from cohere.responses import Chat, Generations
@@ -55,10 +55,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -103,8 +103,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -171,9 +171,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -216,7 +216,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Generations,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -256,7 +256,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: StreamingGenerations,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -317,9 +317,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -377,8 +377,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: Chat,
- prompt_messages: list[PromptMessage],
- stop: Optional[list[str]] = None,
+ prompt_messages: List[PromptMessage],
+ stop: Optional[List[str]] = None,
) -> LLMResult:
"""
Handle llm chat response
@@ -429,8 +429,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: StreamingChat,
- prompt_messages: list[PromptMessage],
- stop: Optional[list[str]] = None,
+ prompt_messages: List[PromptMessage],
+ stop: Optional[List[str]] = None,
) -> Generator:
"""
Handle llm chat stream response
@@ -517,8 +517,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(
- self, prompt_messages: list[PromptMessage]
- ) -> tuple[str, list[dict]]:
+ self, prompt_messages: List[PromptMessage]
+    ) -> Tuple[str, List[dict]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
@@ -586,7 +586,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response.length
def _num_tokens_from_messages(
- self, model: str, credentials: dict, messages: list[PromptMessage]
+ self, model: str, credentials: dict, messages: List[PromptMessage]
) -> int:
"""Calculate num tokens Cohere model."""
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
@@ -650,7 +650,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return entity
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py
index 86d0b22f..4cf159c1 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/rerank/rerank.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, List, Optional, Type
import cohere
@@ -32,7 +32,7 @@ class CohereRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@@ -99,7 +99,7 @@ class CohereRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
index bf4821e9..49044053 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
@@ -35,7 +35,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -51,7 +51,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
- embeddings: list[list[float]] = [[] for _ in range(len(texts))]
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@@ -79,8 +79,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
- results: list[list[list[float]]] = [[] for _ in range(len(texts))]
- num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
+        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@@ -105,7 +105,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -161,8 +161,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError(str(ex))
def _embedding_invoke(
- self, model: str, credentials: dict, texts: list[str]
- ) -> tuple[list[list[float]], int]:
+ self, model: str, credentials: dict, texts: List[str]
+    ) -> Tuple[List[List[float]], int]:
"""
Invoke embedding model
@@ -216,7 +216,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
return usage
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
index a23df8aa..36b35ca6 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,7 +1,6 @@
import json
import logging
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Dict, Generator, List, Optional, Type, Union
import google.api_core.exceptions as exceptions
import google.generativeai as genai
@@ -66,10 +65,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -102,8 +101,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -118,7 +117,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return self._get_num_tokens_by_gpt2(prompt)
- def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
@@ -155,10 +154,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -249,7 +248,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: GenerateContentResponse,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -306,7 +305,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
response: GenerateContentResponse,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -416,7 +415,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return glm_content
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
        The key is the error type thrown to the caller
@@ -472,8 +471,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
tool_call = None
if response_function_call:
- from google.protobuf import json_format
-
if isinstance(response_function_call, FunctionCall):
map_composite_dict = dict(response_function_call.args.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
index 58a76581..e0ca7ade 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/groq/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@@ -16,10 +15,10 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
index e14f2653..bc87e2ce 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
from model_providers.core.model_runtime.errors.invoke import (
@@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonHuggingfaceHub:
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
index d47f0461..585a19d5 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
@@ -44,10 +43,10 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -81,8 +80,8 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
@@ -161,7 +160,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
return entity
@staticmethod
- def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
+ def _get_customizable_model_parameter_rules() -> List[ParameterRule]:
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
DefaultParameterName.TEMPERATURE
).copy()
@@ -253,7 +252,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
response: Generator,
) -> Generator:
index = -1
@@ -300,7 +299,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
response: any,
) -> LLMResult:
if isinstance(response, str):
@@ -355,7 +354,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
return model_info.pipeline_tag
- def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list
text = "".join(
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
index a0451017..8afaff1e 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import json
import time
-from typing import Optional
+from typing import List, Optional
import numpy as np
import requests
@@ -35,7 +35,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
@@ -62,7 +62,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
embeddings=self._mean_pooling(embeddings), usage=usage, model=model
)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
@@ -132,12 +132,12 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
return entity
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
- # Returned values are a list of floats, or a list[list[floats]]
+    # Returned values are a list of floats, or a List[List[floats]]
# (depending on if you sent a string or a list of string,
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
# This should be explained on the model's README.)
@staticmethod
- def _mean_pooling(embeddings: list) -> list[float]:
+ def _mean_pooling(embeddings: list) -> List[float]:
# If automatic reduction by giving model, no need to mean_pooling.
# For example one: List[List[float]]
if not isinstance(embeddings[0][0], list):
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py
index 09a3b2fa..90147dea 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/rerank/rerank.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, List, Optional, Type
import httpx
@@ -32,7 +32,7 @@ class JinaRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@@ -108,7 +108,7 @@ class JinaRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
"""
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
index 48d8fecb..52a95881 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import time
from json import JSONDecodeError, dumps
-from typing import Optional
+from typing import Dict, List, Optional, Type
from requests import post
@@ -34,7 +34,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
"""
api_base: str = "https://api.jina.ai/v1/embeddings"
- models: list[str] = [
+ models: List[str] = [
"jina-embeddings-v2-base-en",
"jina-embeddings-v2-small-en",
"jina-embeddings-v2-base-zh",
@@ -45,7 +45,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -113,7 +113,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
return result
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -142,7 +142,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid api key")
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py
index 0545b52d..b359df99 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/localai/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import cast
+from typing import Dict, Generator, List, Type, Union, cast
from httpx import Timeout
from openai import (
@@ -64,13 +63,13 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@@ -86,14 +85,14 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages, tools=tools)
def _num_tokens_from_messages(
- self, messages: list[PromptMessage], tools: list[PromptMessageTool]
+ self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for baichuan model
@@ -156,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return num_tokens
- def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+ def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@@ -224,7 +223,7 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
completion_model = None
if credentials["completion_type"] == "chat_completion":
completion_model = LLMMode.CHAT.value
@@ -286,13 +285,13 @@ class LocalAILarguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
kwargs = self._to_client_kwargs(credentials)
# init model client
client = OpenAI(**kwargs)
@@ -414,7 +413,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return message_dict
def _convert_prompt_message_to_completion_prompts(
- self, messages: list[PromptMessage]
+ self, messages: List[PromptMessage]
) -> str:
"""
Convert PromptMessage to completion prompts
@@ -438,7 +437,7 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_completion_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Completion,
) -> LLMResult:
@@ -489,10 +488,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: ChatCompletion,
- tools: list[PromptMessageTool],
+ tools: List[PromptMessageTool],
) -> LLMResult:
"""
Handle llm chat response
@@ -547,10 +546,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_completion_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Stream[Completion],
- tools: list[PromptMessageTool],
+ tools: List[PromptMessageTool],
) -> Generator:
full_response = ""
@@ -613,10 +612,10 @@ class LocalAILarguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Stream[ChatCompletionChunk],
- tools: list[PromptMessageTool],
+ tools: List[PromptMessageTool],
) -> Generator:
full_response = ""
@@ -691,8 +690,8 @@ class LocalAILarguageModel(LargeLanguageModel):
full_response += delta.delta.content
def _extract_response_tool_calls(
- self, response_function_calls: list[FunctionCall]
- ) -> list[AssistantPromptMessage.ToolCall]:
+ self, response_function_calls: List[FunctionCall]
+ ) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@@ -714,7 +713,7 @@ class LocalAILarguageModel(LargeLanguageModel):
return tool_calls
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
index b42cfb3c..5823ebc9 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import time
from json import JSONDecodeError, dumps
-from typing import Optional
+from typing import Dict, List, Optional, Type, Union
from requests import post
from yarl import URL
@@ -42,7 +42,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -121,7 +121,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
return result
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -138,7 +138,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
def _get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
"""
Get customizable model schema
@@ -177,7 +177,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py
index 7f2edebc..8a6df376 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion.py
@@ -1,6 +1,5 @@
-from collections.abc import Generator
from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@@ -27,10 +26,10 @@ class MinimaxChatCompletion:
model: str,
api_key: str,
group_id: str,
- prompt_messages: list[MinimaxMessage],
+ prompt_messages: List[MinimaxMessage],
model_parameters: dict,
- tools: list[dict[str, Any]],
- stop: list[str] | None,
+ tools: List[Dict[str, Any]],
+ stop: Union[List[str], None],
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
index cbe6979f..8643d33e 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
@@ -1,6 +1,5 @@
-from collections.abc import Generator
from json import dumps, loads
-from typing import Any, Union
+from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@@ -28,10 +27,10 @@ class MinimaxChatCompletionPro:
model: str,
api_key: str,
group_id: str,
- prompt_messages: list[MinimaxMessage],
+ prompt_messages: List[MinimaxMessage],
model_parameters: dict,
- tools: list[dict[str, Any]],
- stop: list[str] | None,
+ tools: List[Dict[str, Any]],
+ stop: Union[List[str], None],
stream: bool,
user: str,
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py
index 1696bc7a..aea0f9e1 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/llm.py
@@ -58,13 +58,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
return self._generate(
model,
credentials,
@@ -110,13 +110,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(
- self, messages: list[PromptMessage], tools: list[PromptMessageTool]
+ self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for minimax model
@@ -137,13 +137,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
"""
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
"""
@@ -227,7 +227,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: MinimaxMessage,
) -> LLMResult:
@@ -250,7 +250,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[MinimaxMessage, None, None],
) -> Generator[LLMResultChunk, None, None]:
@@ -319,7 +319,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py
index 5e9d73dd..a2fc1220 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/llm/types.py
@@ -11,11 +11,11 @@ class MinimaxMessage:
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Dict[str, int] = None
stop_reason: str = ""
- function_call: dict[str, Any] = None
+ function_call: Dict[str, Any] = None
- def to_dict(self) -> dict[str, Any]:
+ def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
"sender_type": "BOT",
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
index ff764bb9..9974d585 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import time
from json import dumps
-from typing import Optional
+from typing import Dict, List, Optional, Type
from requests import post
@@ -44,7 +44,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -103,7 +103,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
return result
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -146,7 +146,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
raise InternalServerError(msg)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py
index 364d4a92..948d308a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/mistralai/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@@ -16,10 +15,10 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py
index fbec3157..6adc96e2 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/model_provider_factory.py
@@ -1,7 +1,7 @@
import importlib
import logging
import os
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
from pydantic import BaseModel
@@ -42,7 +42,7 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory:
# init cache provider by default
init_cache: bool = False
- model_provider_extensions: dict[str, ModelProviderExtension] = None
+ model_provider_extensions: Dict[str, ModelProviderExtension] = None
def __init__(self, init_cache: bool = False) -> None:
# for cache in memory
@@ -51,7 +51,7 @@ class ModelProviderFactory:
def get_providers(
self, provider_name: Union[str, set] = ""
- ) -> list[ProviderEntity]:
+ ) -> List[ProviderEntity]:
"""
Get all providers
:return: list of providers
@@ -159,8 +159,8 @@ class ModelProviderFactory:
self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
- provider_configs: Optional[list[ProviderConfig]] = None,
- ) -> list[SimpleProviderEntity]:
+ provider_configs: Optional[List[ProviderConfig]] = None,
+ ) -> List[SimpleProviderEntity]:
"""
Get all models for given model type
@@ -234,7 +234,7 @@ class ModelProviderFactory:
return model_provider_instance
- def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
+ def _get_model_provider_map(self) -> Dict[str, ModelProviderExtension]:
if self.model_provider_extensions:
return self.model_provider_extensions
@@ -254,7 +254,7 @@ class ModelProviderFactory:
position_map = get_position_map(model_providers_path)
# traverse all model_provider_dir_paths
- model_providers: list[ModelProviderExtension] = []
+ model_providers: List[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py
index fbf4bf29..8bd0f500 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/moonshot/llm/llm.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from typing import Optional, Union
+from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@@ -16,10 +16,10 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
index 19a7cc58..3ce6f73f 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -3,7 +3,7 @@ import logging
import re
from collections.abc import Generator
from decimal import Decimal
-from typing import Optional, Union, cast
+from typing import Dict, List, Optional, Type, Union, cast
from urllib.parse import urljoin
import requests
@@ -63,10 +63,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -97,8 +97,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -159,9 +159,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -258,7 +258,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
completion_type: LLMMode,
response: requests.Response,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm completion response
@@ -310,7 +310,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict,
completion_type: LLMMode,
response: requests.Response,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm completion stream response
@@ -462,7 +462,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return message_dict
- def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
+ def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
"""
Calculate num tokens.
@@ -700,7 +700,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return entity
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
index fd037190..15a8911a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
@@ -2,7 +2,7 @@ import json
import logging
import time
from decimal import Decimal
-from typing import Optional
+from typing import Dict, List, Optional, Type
from urllib.parse import urljoin
import numpy as np
@@ -48,7 +48,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -123,7 +123,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
embeddings=batched_embeddings, usage=usage, model=model
)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer
@@ -211,7 +211,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
return usage
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py
index c459c20b..b6f3ee40 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
import openai
from httpx import Timeout
@@ -35,7 +37,7 @@ class _CommonOpenAI:
return credentials_kwargs
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py
index 0b886642..230ad76e 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/llm/llm.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import List, Optional, Union, cast
import tiktoken
from openai import OpenAI, Stream
@@ -72,10 +72,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -128,13 +128,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@@ -194,12 +194,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
+ user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@@ -242,12 +242,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
+ user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@@ -287,8 +287,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -366,7 +366,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
- def remote_models(self, credentials: dict) -> list[AIModelEntity]:
+ def remote_models(self, credentials: dict) -> List[AIModelEntity]:
"""
Return remote models if credentials are provided.
@@ -424,9 +424,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -479,7 +479,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Completion,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm completion response
@@ -528,7 +528,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[Completion],
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm completion stream response
@@ -599,10 +599,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -676,8 +676,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: ChatCompletion,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> LLMResult:
"""
Handle llm chat response
@@ -740,8 +740,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> Generator:
"""
Handle llm chat stream response
@@ -851,8 +851,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
def _extract_response_tool_calls(
self,
- response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
- ) -> list[AssistantPromptMessage.ToolCall]:
+ response_tool_calls: List[
+ Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
+ ],
+ ) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@@ -877,7 +879,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return tool_calls
def _extract_response_function_call(
- self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
+ self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
@@ -967,7 +969,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(
- self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
+ self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
@@ -992,8 +994,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
def _num_tokens_from_messages(
self,
model: str,
- messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@@ -1068,7 +1070,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
return num_tokens
def _num_tokens_for_tools(
- self, encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
+ self, encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py
index 8ff1958a..3797e815 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/moderation/moderation.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import List, Optional
from openai import OpenAI
from openai.types import ModerationCreateResponse
@@ -82,7 +82,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
raise CredentialsValidateFailedError(str(ex))
def _moderation_invoke(
- self, model: str, client: OpenAI, texts: list[str]
+ self, model: str, client: OpenAI, texts: List[str]
) -> ModerationCreateResponse:
"""
Invoke moderation model
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
index ef04e13f..f502b544 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import base64
import time
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import tiktoken
@@ -31,7 +31,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -58,7 +58,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
- embeddings: list[list[float]] = [[] for _ in range(len(texts))]
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
tokens = []
indices = []
used_tokens = 0
@@ -89,8 +89,8 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch
- results: list[list[list[float]]] = [[] for _ in range(len(texts))]
- num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
+ results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
@@ -118,7 +118,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -167,9 +167,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
self,
model: str,
client: OpenAI,
- texts: Union[list[str], str],
+ texts: Union[List[str], str],
extra_model_kwargs: dict,
- ) -> tuple[list[list[float]], int]:
+ ) -> Tuple[List[List[float]], int]:
"""
Invoke embedding model
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py
index d5c3d879..3abbff45 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
import requests
from model_providers.core.model_runtime.errors.invoke import (
@@ -12,7 +14,7 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonOAI_API_Compat:
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 05a2c7c9..014008dd 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Generator
from decimal import Decimal
-from typing import Optional, Union, cast
+from typing import List, Optional, Union, cast
from urllib.parse import urljoin
import requests
@@ -61,10 +61,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -98,8 +98,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -282,10 +282,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -384,7 +384,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model: str,
credentials: dict,
response: requests.Response,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -516,7 +516,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model: str,
credentials: dict,
response: requests.Response,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
response_json = response.json()
@@ -649,7 +649,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return message_dict
def _num_tokens_from_string(
- self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
+ self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
) -> int:
"""
Approximate num tokens for model with gpt2 tokenizer.
@@ -669,8 +669,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
def _num_tokens_from_messages(
self,
model: str,
- messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
@@ -722,7 +722,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return num_tokens
- def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+ def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.
@@ -769,8 +769,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return num_tokens
def _extract_response_tool_calls(
- self, response_tool_calls: list[dict]
- ) -> list[AssistantPromptMessage.ToolCall]:
+ self, response_tool_calls: List[dict]
+ ) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index 0c71cfa0..fcf13f62 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -1,7 +1,7 @@
import json
import time
from decimal import Decimal
-from typing import Optional
+from typing import List, Optional
from urllib.parse import urljoin
import numpy as np
@@ -40,7 +40,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -131,7 +131,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
embeddings=batched_embeddings, usage=usage, model=model
)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Approximate number of tokens for given messages using GPT2 tokenizer
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py
index 271eca7e..04c46075 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/llm.py
@@ -1,4 +1,4 @@
-from collections.abc import Generator
+from typing import Dict, Generator, List, Type, Union
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.llm_entities import (
@@ -54,13 +54,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
return self._generate(
model,
credentials,
@@ -105,13 +105,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
return self._num_tokens_from_messages(prompt_messages, tools)
def _num_tokens_from_messages(
- self, messages: list[PromptMessage], tools: list[PromptMessageTool]
+ self, messages: List[PromptMessage], tools: List[PromptMessageTool]
) -> int:
"""
Calculate num tokens for OpenLLM model
@@ -124,13 +124,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
client = OpenLLMGenerate()
response = client.generate(
model_name=model,
@@ -183,7 +183,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: OpenLLMGenerateMessage,
) -> LLMResult:
@@ -206,7 +206,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[OpenLLMGenerateMessage, None, None],
) -> Generator[LLMResultChunk, None, None]:
@@ -249,7 +249,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
@@ -298,7 +298,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
return entity
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index 05dc9488..1c0ab40d 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -20,10 +20,10 @@ class OpenLLMGenerateMessage:
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Dict[str, int] = None
stop_reason: str = ""
- def to_dict(self) -> dict[str, Any]:
+ def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@@ -40,9 +40,9 @@ class OpenLLMGenerate:
server_url: str,
model_name: str,
stream: bool,
- model_parameters: dict[str, Any],
- stop: list[str],
- prompt_messages: list[OpenLLMGenerateMessage],
+ model_parameters: Dict[str, Any],
+ stop: List[str],
+ prompt_messages: List[OpenLLMGenerateMessage],
user: str,
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
if not server_url:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
index 5e1ed3a7..c56442e1 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import time
from json import dumps
-from typing import Optional
+from typing import Dict, List, Optional, Type
from requests import post
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -35,7 +35,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -89,7 +89,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
return result
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -118,7 +118,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError("Invalid server_url")
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py
index 582cb8aa..6299cfc0 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
from replicate.exceptions import ModelError, ReplicateError
from model_providers.core.model_runtime.errors.invoke import (
@@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
class _CommonReplicate:
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {InvokeBadRequestError: [ReplicateError, ModelError]}
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py
index 987cb4d0..9b72f011 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
from replicate import Client as ReplicateClient
from replicate.exceptions import ReplicateError
@@ -43,10 +42,10 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -86,8 +85,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
@@ -167,7 +166,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
@classmethod
def _get_customizable_model_parameter_rules(
cls, model: str, credentials: dict
- ) -> list[ParameterRule]:
+ ) -> List[ParameterRule]:
version = credentials["model_version"]
client = ReplicateClient(
@@ -215,8 +214,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
model: str,
credentials: dict,
prediction: Prediction,
- stop: list[str],
- prompt_messages: list[PromptMessage],
+ stop: List[str],
+ prompt_messages: List[PromptMessage],
) -> Generator:
index = -1
current_completion: str = ""
@@ -281,8 +280,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
model: str,
credentials: dict,
prediction: Prediction,
- stop: list[str],
- prompt_messages: list[PromptMessage],
+ stop: List[str],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
current_completion: str = ""
stop_condition_reached = False
@@ -332,7 +331,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
elif param_type == "string":
return "string"
- def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list
text = "".join(
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
index a6884360..e45bf19a 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
import json
import time
-from typing import Optional
+from typing import List, Optional
from replicate import Client as ReplicateClient
@@ -31,7 +31,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
client = ReplicateClient(
@@ -52,7 +52,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
@@ -124,8 +124,8 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
client: ReplicateClient,
replicate_model_version: str,
text_input_key: str,
- texts: list[str],
- ) -> list[list[float]]:
+ texts: List[str],
+ ) -> List[List[float]]:
if text_input_key in ("text", "inputs"):
embeddings = []
for text in texts:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py
index c7ea29f3..02741fd7 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/spark/llm/llm.py
@@ -1,6 +1,6 @@
import threading
from collections.abc import Generator
-from typing import Optional, Union
+from typing import Dict, List, Optional, Type, Union
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -37,10 +37,10 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -66,8 +66,8 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -109,9 +109,9 @@ class SparkLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -171,13 +171,12 @@ class SparkLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
client: SparkLLMClient,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
:param model: model name
- :param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
@@ -222,7 +221,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
model: str,
credentials: dict,
client: SparkLLMClient,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -300,7 +299,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
return message_text
- def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@@ -317,7 +316,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py
index c954affd..72b4be3c 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from typing import Optional, Union
+from typing import List, Optional, Union
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
@@ -21,10 +21,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -50,10 +50,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -81,8 +81,8 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py
index da62624a..b3497560 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
from model_providers.core.model_runtime.errors.invoke import InvokeError
@@ -11,7 +13,7 @@ class _CommonTongyi:
return credentials_kwargs
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py
index d7bf35f3..8aac0c74 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/_client.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Tongyi
@@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult
class EnhanceTongyi(Tongyi):
@property
- def _default_params(self) -> dict[str, Any]:
+ def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
normal_params = {"top_p": self.top_p, "api_key": self.dashscope_api_key}
@@ -16,13 +16,13 @@ class EnhanceTongyi(Tongyi):
def _generate(
self,
- prompts: list[str],
- stop: Optional[list[str]] = None,
+ prompts: List[str],
+ stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = []
- params: dict[str, Any] = {
+ params: Dict[str, Any] = {
**{"model": self.model_name},
**self._default_params,
**kwargs,
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py
index f2b0741c..058027cc 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from typing import Optional, Union
+from typing import Dict, List, Optional, Type, Union
from dashscope import get_tokenizer
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
@@ -50,10 +50,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -79,14 +79,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- callbacks: list[Callback] = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ callbacks: List[Callback] = None,
+ ) -> Union[LLMResult, Generator]:
"""
Wrapper for code block mode
"""
@@ -174,8 +174,8 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -220,9 +220,9 @@ if you are not sure about the structure.
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- stop: Optional[list[str]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -289,7 +289,7 @@ if you are not sure about the structure.
model: str,
credentials: dict,
response: DashScopeAPIResponse,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -326,7 +326,7 @@ if you are not sure about the structure.
model: str,
credentials: dict,
responses: Generator,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -412,7 +412,7 @@ if you are not sure about the structure.
return message_text
- def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
+ def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
@@ -429,8 +429,8 @@ if you are not sure about the structure.
return text.rstrip()
def _convert_prompt_messages_to_tongyi_messages(
- self, prompt_messages: list[PromptMessage]
- ) -> list[dict]:
+ self, prompt_messages: List[PromptMessage]
+ ) -> List[dict]:
"""
Convert prompt messages to tongyi messages
@@ -466,7 +466,7 @@ if you are not sure about the structure.
return tongyi_messages
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py
index aa3e7f88..4ab2dd1e 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/tongyi/tts/tts.py
@@ -4,7 +4,7 @@ from io import BytesIO
from typing import Optional
import dashscope
-from fastapi.responses import StreamingResponse
+from fastapi.responses import Response, StreamingResponse
from pydub import AudioSegment
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
index 39464d4e..2b6d82ac 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
@@ -1,9 +1,8 @@
-from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
-from typing import Any, Union
+from typing import Any, Dict, Generator, List, Union
from requests import Response, post
@@ -19,7 +18,7 @@ from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_err
)
# map api_key to access_token
-baidu_access_tokens: dict[str, "BaiduAccessToken"] = {}
+baidu_access_tokens: Dict[str, "BaiduAccessToken"] = {}
baidu_access_tokens_lock = Lock()
@@ -118,10 +117,10 @@ class ErnieMessage:
role: str = Role.USER.value
content: str
- usage: dict[str, int] = None
+ usage: Dict[str, int] = None
stop_reason: str = ""
- def to_dict(self) -> dict[str, Any]:
+ def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"content": self.content,
@@ -156,11 +155,11 @@ class ErnieBotModel:
self,
model: str,
stream: bool,
- messages: list[ErnieMessage],
- parameters: dict[str, Any],
+ messages: List[ErnieMessage],
+ parameters: Dict[str, Any],
timeout: int,
- tools: list[PromptMessageTool],
- stop: list[str],
+ tools: List[PromptMessageTool],
+ stop: List[str],
user: str,
) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
# check parameters
@@ -243,15 +242,15 @@ class ErnieBotModel:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
- def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
+ def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]
def _check_parameters(
self,
model: str,
- parameters: dict[str, Any],
- tools: list[PromptMessageTool],
- stop: list[str],
+ parameters: Dict[str, Any],
+ tools: List[PromptMessageTool],
+ stop: List[str],
) -> None:
if model not in self.api_bases:
raise BadRequestError(f"Invalid model: {model}")
@@ -276,13 +275,13 @@ class ErnieBotModel:
def _build_request_body(
self,
model: str,
- messages: list[ErnieMessage],
+ messages: List[ErnieMessage],
stream: bool,
- parameters: dict[str, Any],
- tools: list[PromptMessageTool],
- stop: list[str],
+ parameters: Dict[str, Any],
+ tools: List[PromptMessageTool],
+ stop: List[str],
user: str,
- ) -> dict[str, Any]:
+ ) -> Dict[str, Any]:
# if model in self.function_calling_supports:
# return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
return self._build_chat_request_body(
@@ -292,13 +291,13 @@ class ErnieBotModel:
def _build_function_calling_request_body(
self,
model: str,
- messages: list[ErnieMessage],
+ messages: List[ErnieMessage],
stream: bool,
- parameters: dict[str, Any],
- tools: list[PromptMessageTool],
- stop: list[str],
+ parameters: Dict[str, Any],
+ tools: List[PromptMessageTool],
+ stop: List[str],
user: str,
- ) -> dict[str, Any]:
+ ) -> Dict[str, Any]:
if len(messages) % 2 == 0:
raise BadRequestError("The number of messages should be odd.")
if messages[0].role == "function":
@@ -311,12 +310,12 @@ class ErnieBotModel:
def _build_chat_request_body(
self,
model: str,
- messages: list[ErnieMessage],
+ messages: List[ErnieMessage],
stream: bool,
- parameters: dict[str, Any],
- stop: list[str],
+ parameters: Dict[str, Any],
+ stop: List[str],
user: str,
- ) -> dict[str, Any]:
+ ) -> Dict[str, Any]:
if len(messages) == 0:
raise BadRequestError("The number of messages should not be zero.")
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py
index 1f7a638b..f53c9872 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/wenxin/llm/llm.py
@@ -1,5 +1,4 @@
-from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Dict, Generator, List, Optional, Type, Union, cast
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import (
@@ -59,13 +58,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
return self._generate(
model=model,
credentials=credentials,
@@ -81,13 +80,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
- callbacks: list[Callback] = None,
+ callbacks: List[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
@@ -140,12 +139,12 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
+ user: Union[str, None] = None,
response_format: str = "JSON",
) -> None:
"""
@@ -187,15 +186,15 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
# tools is not supported yet
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(
self,
- messages: list[PromptMessage],
+ messages: List[PromptMessage],
) -> int:
"""Calculate num tokens for baichuan model"""
@@ -234,13 +233,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
instance = ErnieBotModel(
api_key=credentials["api_key"],
secret_key=credentials["secret_key"],
@@ -304,7 +303,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: ErnieMessage,
) -> LLMResult:
@@ -325,7 +324,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _handle_chat_generate_stream_response(
self,
model: str,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
credentials: dict,
response: Generator[ErnieMessage, None, None],
) -> Generator:
@@ -367,7 +366,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py
index aa139170..e521c505 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -1,5 +1,5 @@
from collections.abc import Generator, Iterator
-from typing import cast
+from typing import Dict, List, Type, Union, cast
from openai import (
APIConnectionError,
@@ -81,13 +81,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
"""
invoke LLM
@@ -168,8 +168,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool] | None = None,
+ prompt_messages: List[PromptMessage],
+ tools: Union[List[PromptMessageTool], None] = None,
) -> int:
"""
get number of tokens
@@ -181,8 +181,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _num_tokens_from_messages(
self,
- messages: list[PromptMessage],
- tools: list[PromptMessageTool],
+ messages: List[PromptMessage],
+ tools: List[PromptMessageTool],
is_completion_model: bool = False,
) -> int:
def tokens(text: str):
@@ -240,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return num_tokens
- def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
+ def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling
@@ -284,7 +284,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return num_tokens
- def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
+ def _convert_prompt_message_to_text(self, message: List[PromptMessage]) -> str:
"""
convert prompt message to text
"""
@@ -337,7 +337,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
@@ -412,14 +412,14 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
extra_model_kwargs: XinferenceModelExtraParameter,
- tools: list[PromptMessageTool] | None = None,
- stop: list[str] | None = None,
+ tools: Union[List[PromptMessageTool], None] = None,
+ stop: Union[List[str], None] = None,
stream: bool = True,
- user: str | None = None,
- ) -> LLMResult | Generator:
+ user: Union[str, None] = None,
+ ) -> Union[LLMResult, Generator]:
"""
generate text from LLM
@@ -525,8 +525,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
def _extract_response_tool_calls(
self,
- response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
- ) -> list[AssistantPromptMessage.ToolCall]:
+        response_tool_calls: List[
+            Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
+        ],
+ ) -> List[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
@@ -551,7 +553,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
return tool_calls
def _extract_response_function_call(
- self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
+ self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
@@ -576,8 +578,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
+ prompt_messages: List[PromptMessage],
+ tools: List[PromptMessageTool],
resp: ChatCompletion,
) -> LLMResult:
"""
@@ -633,8 +635,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
+ prompt_messages: List[PromptMessage],
+ tools: List[PromptMessageTool],
resp: Iterator[ChatCompletionChunk],
) -> Generator:
"""
@@ -721,8 +723,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
+ prompt_messages: List[PromptMessage],
+ tools: List[PromptMessageTool],
resp: Completion,
) -> LLMResult:
"""
@@ -765,8 +767,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: list[PromptMessageTool],
+ prompt_messages: List[PromptMessage],
+ tools: List[PromptMessageTool],
resp: Iterator[Completion],
) -> Generator:
"""
@@ -834,7 +836,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
full_response += delta.text
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py
index 4291d3c1..e35400b7 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/rerank/rerank.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, List, Optional, Type, Union
from xinference_client.client.restful.restful_client import (
Client,
@@ -41,7 +41,7 @@ class XinferenceRerankModel(RerankModel):
model: str,
credentials: dict,
query: str,
- docs: list[str],
+ docs: List[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
@@ -133,7 +133,7 @@ class XinferenceRerankModel(RerankModel):
raise CredentialsValidateFailedError(str(ex))
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@@ -152,7 +152,7 @@ class XinferenceRerankModel(RerankModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
index 2cabf59b..cc3d30f0 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
@@ -1,5 +1,5 @@
import time
-from typing import Optional
+from typing import Dict, List, Optional, Type, Union
from xinference_client.client.restful.restful_client import (
Client,
@@ -46,7 +46,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -116,7 +116,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
return result
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -167,7 +167,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
raise CredentialsValidateFailedError(e)
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
@@ -210,7 +210,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
def get_customizable_model_schema(
self, model: str, credentials: dict
- ) -> AIModelEntity | None:
+ ) -> Union[AIModelEntity, None]:
"""
used to define customizable model schema
"""
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py
index 6194a0cb..885ebf47 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -1,5 +1,6 @@
from threading import Lock
from time import time
+from typing import List
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
@@ -10,7 +11,7 @@ from yarl import URL
class XinferenceModelExtraParameter:
model_format: str
model_handle_type: str
- model_ability: list[str]
+ model_ability: List[str]
max_tokens: int = 512
context_length: int = 2048
support_function_call: bool = False
@@ -19,7 +20,7 @@ class XinferenceModelExtraParameter:
self,
model_format: str,
model_handle_type: str,
- model_ability: list[str],
+ model_ability: List[str],
support_function_call: bool,
max_tokens: int,
context_length: int,
@@ -115,7 +116,7 @@ class XinferenceHelper:
model_handle_type = "chat"
else:
raise NotImplementedError(
- f"xinference model handle type {model_handle_type} is not supported"
+ f"xinference model handle type {response_json.get('model_type')} is not supported"
)
support_function_call = "tools" in model_ability
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
index 6f6595ed..f19135de 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/_common.py
@@ -1,3 +1,5 @@
+from typing import Dict, List, Type
+
from model_providers.core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
@@ -22,12 +24,17 @@ class _CommonZhipuaiAI:
else credentials["zhipuai_api_key"]
if "zhipuai_api_key" in credentials
else None,
+ "api_base": credentials["api_base"]
+ if "api_base" in credentials
+ else credentials["zhipuai_api_base"]
+ if "zhipuai_api_base" in credentials
+ else None,
}
return credentials_kwargs
@property
- def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+ def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py
index b3721cd0..b987e6fc 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -1,5 +1,14 @@
-from collections.abc import Generator
-from typing import Optional, Union
+from typing import Generator, List, Optional, Union
+
+from zhipuai import (
+ ZhipuAI,
+)
+from zhipuai.types.chat.chat_completion import (
+ Completion,
+)
+from zhipuai.types.chat.chat_completion_chunk import (
+ ChatCompletionChunk,
+)
from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -24,15 +33,6 @@ from model_providers.core.model_runtime.model_providers.__base.large_language_mo
from model_providers.core.model_runtime.model_providers.zhipuai._common import (
_CommonZhipuaiAI,
)
-from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import (
- ZhipuAI,
-)
-from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import (
- Completion,
-)
-from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import (
- ChatCompletionChunk,
-)
from model_providers.core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
@@ -53,10 +53,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -91,8 +91,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
)
# def _transform_json_prompts(self, model: str, credentials: dict,
- # prompt_messages: list[PromptMessage], model_parameters: dict,
- # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+    #                             prompt_messages: List[PromptMessage], model_parameters: dict,
+    #                             tools: List[PromptMessageTool] | None = None, stop: List[str] | None = None,
# stream: bool = True, user: str | None = None) \
# -> None:
# """
@@ -126,8 +126,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- prompt_messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ prompt_messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -172,10 +172,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
self,
model: str,
credentials_kwargs: dict,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
model_parameters: dict,
- tools: Optional[list[PromptMessageTool]] = None,
- stop: Optional[list[str]] = None,
+ tools: Optional[List[PromptMessageTool]] = None,
+ stop: Optional[List[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
@@ -195,7 +195,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop:
extra_model_kwargs["stop"] = stop
- client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+ client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+ api_key=credentials_kwargs["api_key"])
if len(prompt_messages) == 0:
raise ValueError("At least one message is required")
@@ -205,7 +206,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
prompt_messages = prompt_messages[1:]
# resolve zhipuai model not support system message and user message, assistant message must be in sequence
- new_prompt_messages: list[PromptMessage] = []
+ new_prompt_messages: List[PromptMessage] = []
for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy()
if copy_prompt_message.role in [
@@ -375,9 +376,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- tools: Optional[list[PromptMessageTool]],
+ tools: Optional[List[PromptMessageTool]],
response: Completion,
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -388,7 +389,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:return: llm response
"""
text = ""
- assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
+ assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
for choice in response.choices:
if choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
@@ -430,9 +431,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
self,
model: str,
credentials: dict,
- tools: Optional[list[PromptMessageTool]],
+ tools: Optional[List[PromptMessageTool]],
responses: Generator[ChatCompletionChunk, None, None],
- prompt_messages: list[PromptMessage],
+ prompt_messages: List[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -454,7 +455,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
):
continue
- assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
+ assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
for tool_call in delta.delta.tool_calls or []:
if tool_call.type == "function":
assistant_tool_calls.append(
@@ -531,8 +532,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def _convert_messages_to_prompt(
self,
- messages: list[PromptMessage],
- tools: Optional[list[PromptMessageTool]] = None,
+ messages: List[PromptMessage],
+ tools: Optional[List[PromptMessageTool]] = None,
) -> str:
"""
:param messages: List of PromptMessage to combine.
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
index ca75fe79..9e12b53c 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
+++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
@@ -1,5 +1,9 @@
import time
-from typing import Optional
+from typing import List, Optional, Tuple
+
+from zhipuai import (
+ ZhipuAI,
+)
from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import (
@@ -15,9 +19,6 @@ from model_providers.core.model_runtime.model_providers.__base.text_embedding_mo
from model_providers.core.model_runtime.model_providers.zhipuai._common import (
_CommonZhipuaiAI,
)
-from model_providers.core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import (
- ZhipuAI,
-)
class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
@@ -29,7 +30,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
self,
model: str,
credentials: dict,
- texts: list[str],
+ texts: List[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
@@ -42,7 +43,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
- client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+ client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+ api_key=credentials_kwargs["api_key"])
embeddings, embedding_used_tokens = self.embed_documents(model, client, texts)
@@ -54,7 +56,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
model=model,
)
- def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+ def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
"""
Get number of tokens for given prompt messages
@@ -83,7 +85,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try:
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
- client = ZhipuAI(api_key=credentials_kwargs["api_key"])
+ client = ZhipuAI(base_url=credentials_kwargs["api_base"],
+ api_key=credentials_kwargs["api_key"])
# call embedding model
self.embed_documents(
@@ -95,8 +98,8 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
raise CredentialsValidateFailedError(str(ex))
def embed_documents(
- self, model: str, client: ZhipuAI, texts: list[str]
- ) -> tuple[list[list[float]], int]:
+ self, model: str, client: ZhipuAI, texts: List[str]
+ ) -> Tuple[List[List[float]], int]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
@@ -116,17 +119,6 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
- def embed_query(self, text: str) -> list[float]:
- """Call out to ZhipuAI's embedding endpoint.
-
- Args:
- text: The text to embed.
-
- Returns:
- Embeddings for the text.
- """
- return self.embed_documents([text])[0]
-
def _calc_response_usage(
self, model: str, credentials: dict, tokens: int
) -> EmbeddingUsage:
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml
index 303a5491..c4e526ac 100644
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml
+++ b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai.yaml
@@ -29,3 +29,12 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 APIKey
en_US: Enter your APIKey
+ - variable: api_base
+ label:
+ zh_Hans: API Base
+ en_US: API Base
+ type: text-input
+ required: false
+ placeholder:
+ zh_Hans: 在此输入您的 API Base
+ en_US: Enter your API Base
\ No newline at end of file
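> Note: the hunks above switch `text_embedding.py` from the vendored `zhipuai_sdk` to the official `zhipuai` package and thread the new optional `api_base` credential (declared in `zhipuai.yaml` above) into the client. A minimal sketch of that call pattern, assuming `pip install zhipuai` and that `_to_credential_kwargs` exposes the `api_key`/`api_base` keys shown in the diff (hypothetical helpers, not the project's code):

```python
from typing import List, Tuple

from zhipuai import ZhipuAI  # official SDK, replacing the deleted zhipuai_sdk package


def build_client(credentials_kwargs: dict) -> ZhipuAI:
    # "api_base" is optional in the provider schema; passing None is assumed
    # to let the SDK fall back to its default endpoint.
    return ZhipuAI(
        base_url=credentials_kwargs.get("api_base"),
        api_key=credentials_kwargs["api_key"],
    )


def embed_texts(client: ZhipuAI, model: str, texts: List[str]) -> Tuple[List[List[float]], int]:
    # Roughly mirrors embed_documents(): one embedding vector per text,
    # plus the total tokens reported by the API.
    vectors: List[List[float]] = []
    used_tokens = 0
    for text in texts:
        response = client.embeddings.create(model=model, input=text)
        vectors.append(list(map(float, response.data[0].embedding)))
        used_tokens += response.usage.total_tokens
    return vectors, used_tokens
```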
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
deleted file mode 100644
index bf9b093c..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from .__version__ import __version__
-from ._client import ZhipuAI
-from .core._errors import (
- APIAuthenticationError,
- APIInternalError,
- APIReachLimitError,
- APIRequestFailedError,
- APIResponseError,
- APIResponseValidationError,
- APIServerFlowExceedError,
- APIStatusError,
- APITimeoutError,
- ZhipuAIError,
-)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
deleted file mode 100644
index 659f38d7..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "v2.0.1"
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
deleted file mode 100644
index 27173a4d..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from __future__ import annotations
-
-import os
-from collections.abc import Mapping
-from typing import Union
-
-import httpx
-from httpx import Timeout
-from typing_extensions import override
-
-from . import api_resource
-from .core import _jwt_token
-from .core._base_type import NOT_GIVEN, NotGiven
-from .core._errors import ZhipuAIError
-from .core._http_client import ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient
-
-
-class ZhipuAI(HttpClient):
- chat: api_resource.chat
- api_key: str
-
- def __init__(
- self,
- *,
- api_key: str | None = None,
- base_url: str | httpx.URL | None = None,
- timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
- max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
- http_client: httpx.Client | None = None,
- custom_headers: Mapping[str, str] | None = None,
- ) -> None:
- # if api_key is None:
- # api_key = os.environ.get("ZHIPUAI_API_KEY")
- if api_key is None:
- raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
- self.api_key = api_key
-
- if base_url is None:
- base_url = os.environ.get("ZHIPUAI_BASE_URL")
- if base_url is None:
- base_url = "https://open.bigmodel.cn/api/paas/v4"
- from .__version__ import __version__
-
- super().__init__(
- version=__version__,
- base_url=base_url,
- timeout=timeout,
- custom_httpx_client=http_client,
- custom_headers=custom_headers,
- )
- self.chat = api_resource.chat.Chat(self)
- self.images = api_resource.images.Images(self)
- self.embeddings = api_resource.embeddings.Embeddings(self)
- self.files = api_resource.files.Files(self)
- self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
-
- @property
- @override
- def _auth_headers(self) -> dict[str, str]:
- api_key = self.api_key
- return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
-
- def __del__(self) -> None:
- if (
- not hasattr(self, "_has_custom_http_client")
- or not hasattr(self, "close")
- or not hasattr(self, "_client")
- ):
- # if the '__init__' method raised an error, self would not have client attr
- return
-
- if self._has_custom_http_client:
- return
-
- self.close()
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
deleted file mode 100644
index 0a90e21e..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .chat import chat
-from .embeddings import Embeddings
-from .files import Files
-from .fine_tuning import fine_tuning
-from .images import Images
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py
deleted file mode 100644
index ce5d737e..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Literal, Optional, Union
-
-import httpx
-
-from ...core._base_api import BaseAPI
-from ...core._base_type import NOT_GIVEN, Headers, NotGiven
-from ...core._http_client import make_user_request_input
-from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
-
-if TYPE_CHECKING:
- from ..._client import ZhipuAI
-
-
-class AsyncCompletions(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def create(
- self,
- *,
- model: str,
- request_id: Optional[str] | NotGiven = NOT_GIVEN,
- do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
- temperature: Optional[float] | NotGiven = NOT_GIVEN,
- top_p: Optional[float] | NotGiven = NOT_GIVEN,
- max_tokens: int | NotGiven = NOT_GIVEN,
- seed: int | NotGiven = NOT_GIVEN,
- messages: Union[str, list[str], list[int], list[list[int]], None],
- stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
- sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
- tools: Optional[object] | NotGiven = NOT_GIVEN,
- tool_choice: str | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- disable_strict_validation: Optional[bool] | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> AsyncTaskStatus:
- _cast_type = AsyncTaskStatus
-
- if disable_strict_validation:
- _cast_type = object
- return self._post(
- "/async/chat/completions",
- body={
- "model": model,
- "request_id": request_id,
- "temperature": temperature,
- "top_p": top_p,
- "do_sample": do_sample,
- "max_tokens": max_tokens,
- "seed": seed,
- "messages": messages,
- "stop": stop,
- "sensitive_word_check": sensitive_word_check,
- "tools": tools,
- "tool_choice": tool_choice,
- },
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=_cast_type,
- enable_stream=False,
- )
-
- def retrieve_completion_result(
- self,
- id: str,
- extra_headers: Headers | None = None,
- disable_strict_validation: Optional[bool] | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Union[AsyncCompletion, AsyncTaskStatus]:
- _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
- if disable_strict_validation:
- _cast_type = object
- return self._get(
- path=f"/async-result/{id}",
- cast_type=_cast_type,
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
deleted file mode 100644
index 92362fc5..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import TYPE_CHECKING
-
-from ...core._base_api import BaseAPI
-from .async_completions import AsyncCompletions
-from .completions import Completions
-
-if TYPE_CHECKING:
- from ..._client import ZhipuAI
-
-
-class Chat(BaseAPI):
- completions: Completions
-
- def __init__(self, client: "ZhipuAI") -> None:
- super().__init__(client)
- self.completions = Completions(client)
- self.asyncCompletions = AsyncCompletions(client)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
deleted file mode 100644
index ec29f338..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Literal, Optional, Union
-
-import httpx
-
-from ...core._base_api import BaseAPI
-from ...core._base_type import NOT_GIVEN, Headers, NotGiven
-from ...core._http_client import make_user_request_input
-from ...core._sse_client import StreamResponse
-from ...types.chat.chat_completion import Completion
-from ...types.chat.chat_completion_chunk import ChatCompletionChunk
-
-if TYPE_CHECKING:
- from ..._client import ZhipuAI
-
-
-class Completions(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def create(
- self,
- *,
- model: str,
- request_id: Optional[str] | NotGiven = NOT_GIVEN,
- do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
- stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
- temperature: Optional[float] | NotGiven = NOT_GIVEN,
- top_p: Optional[float] | NotGiven = NOT_GIVEN,
- max_tokens: int | NotGiven = NOT_GIVEN,
- seed: int | NotGiven = NOT_GIVEN,
- messages: Union[str, list[str], list[int], object, None],
- stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
- sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
- tools: Optional[object] | NotGiven = NOT_GIVEN,
- tool_choice: str | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- disable_strict_validation: Optional[bool] | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Completion | StreamResponse[ChatCompletionChunk]:
- _cast_type = Completion
- _stream_cls = StreamResponse[ChatCompletionChunk]
- if disable_strict_validation:
- _cast_type = object
- _stream_cls = StreamResponse[object]
- return self._post(
- "/chat/completions",
- body={
- "model": model,
- "request_id": request_id,
- "temperature": temperature,
- "top_p": top_p,
- "do_sample": do_sample,
- "max_tokens": max_tokens,
- "seed": seed,
- "messages": messages,
- "stop": stop,
- "sensitive_word_check": sensitive_word_check,
- "stream": stream,
- "tools": tools,
- "tool_choice": tool_choice,
- },
- options=make_user_request_input(
- extra_headers=extra_headers,
- ),
- cast_type=_cast_type,
- enable_stream=stream or False,
- stream_cls=_stream_cls,
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
deleted file mode 100644
index 4da0276a..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional, Union
-
-import httpx
-
-from ..core._base_api import BaseAPI
-from ..core._base_type import NOT_GIVEN, Headers, NotGiven
-from ..core._http_client import make_user_request_input
-from ..types.embeddings import EmbeddingsResponded
-
-if TYPE_CHECKING:
- from .._client import ZhipuAI
-
-
-class Embeddings(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def create(
- self,
- *,
- input: Union[str, list[str], list[int], list[list[int]]],
- model: Union[str],
- encoding_format: str | NotGiven = NOT_GIVEN,
- user: str | NotGiven = NOT_GIVEN,
- sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- disable_strict_validation: Optional[bool] | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> EmbeddingsResponded:
- _cast_type = EmbeddingsResponded
- if disable_strict_validation:
- _cast_type = object
- return self._post(
- "/embeddings",
- body={
- "input": input,
- "model": model,
- "encoding_format": encoding_format,
- "user": user,
- "sensitive_word_check": sensitive_word_check,
- },
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=_cast_type,
- enable_stream=False,
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
deleted file mode 100644
index f48dc4ff..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import httpx
-
-from ..core._base_api import BaseAPI
-from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven
-from ..core._files import is_file_content
-from ..core._http_client import make_user_request_input
-from ..types.file_object import FileObject, ListOfFileObject
-
-if TYPE_CHECKING:
- from .._client import ZhipuAI
-
-__all__ = ["Files"]
-
-
-class Files(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def create(
- self,
- *,
- file: FileTypes,
- purpose: str,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> FileObject:
- if not is_file_content(file):
- prefix = f"Expected file input `{file!r}`"
- raise RuntimeError(
- f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
- ) from None
- files = [("file", file)]
-
- extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
-
- return self._post(
- "/files",
- body={
- "purpose": purpose,
- },
- files=files,
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=FileObject,
- )
-
- def list(
- self,
- *,
- purpose: str | NotGiven = NOT_GIVEN,
- limit: int | NotGiven = NOT_GIVEN,
- after: str | NotGiven = NOT_GIVEN,
- order: str | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> ListOfFileObject:
- return self._get(
- "/files",
- cast_type=ListOfFileObject,
- options=make_user_request_input(
- extra_headers=extra_headers,
- timeout=timeout,
- query={
- "purpose": purpose,
- "limit": limit,
- "after": after,
- "order": order,
- },
- ),
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
deleted file mode 100644
index dc30bd33..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from typing import TYPE_CHECKING
-
-from ...core._base_api import BaseAPI
-from .jobs import Jobs
-
-if TYPE_CHECKING:
- from ..._client import ZhipuAI
-
-
-class FineTuning(BaseAPI):
- jobs: Jobs
-
- def __init__(self, client: "ZhipuAI") -> None:
- super().__init__(client)
- self.jobs = Jobs(client)
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py
deleted file mode 100644
index ecdf455e..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py
+++ /dev/null
@@ -1,111 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional
-
-import httpx
-
-from ...core._base_api import BaseAPI
-from ...core._base_type import NOT_GIVEN, Headers, NotGiven
-from ...core._http_client import make_user_request_input
-from ...types.fine_tuning import (
- FineTuningJob,
- FineTuningJobEvent,
- ListOfFineTuningJob,
- job_create_params,
-)
-
-if TYPE_CHECKING:
- from ..._client import ZhipuAI
-
-__all__ = ["Jobs"]
-
-
-class Jobs(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def create(
- self,
- *,
- model: str,
- training_file: str,
- hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
- suffix: Optional[str] | NotGiven = NOT_GIVEN,
- request_id: Optional[str] | NotGiven = NOT_GIVEN,
- validation_file: Optional[str] | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> FineTuningJob:
- return self._post(
- "/fine_tuning/jobs",
- body={
- "model": model,
- "training_file": training_file,
- "hyperparameters": hyperparameters,
- "suffix": suffix,
- "validation_file": validation_file,
- "request_id": request_id,
- },
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=FineTuningJob,
- )
-
- def retrieve(
- self,
- fine_tuning_job_id: str,
- *,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> FineTuningJob:
- return self._get(
- f"/fine_tuning/jobs/{fine_tuning_job_id}",
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=FineTuningJob,
- )
-
- def list(
- self,
- *,
- after: str | NotGiven = NOT_GIVEN,
- limit: int | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> ListOfFineTuningJob:
- return self._get(
- "/fine_tuning/jobs",
- cast_type=ListOfFineTuningJob,
- options=make_user_request_input(
- extra_headers=extra_headers,
- timeout=timeout,
- query={
- "after": after,
- "limit": limit,
- },
- ),
- )
-
- def list_events(
- self,
- fine_tuning_job_id: str,
- *,
- after: str | NotGiven = NOT_GIVEN,
- limit: int | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> FineTuningJobEvent:
- return self._get(
- f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
- cast_type=FineTuningJobEvent,
- options=make_user_request_input(
- extra_headers=extra_headers,
- timeout=timeout,
- query={
- "after": after,
- "limit": limit,
- },
- ),
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py
deleted file mode 100644
index 63325ce5..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Optional
-
-import httpx
-
-from ..core._base_api import BaseAPI
-from ..core._base_type import NOT_GIVEN, Headers, NotGiven
-from ..core._http_client import make_user_request_input
-from ..types.image import ImagesResponded
-
-if TYPE_CHECKING:
- from .._client import ZhipuAI
-
-
-class Images(BaseAPI):
- def __init__(self, client: ZhipuAI) -> None:
- super().__init__(client)
-
- def generations(
- self,
- *,
- prompt: str,
- model: str | NotGiven = NOT_GIVEN,
- n: Optional[int] | NotGiven = NOT_GIVEN,
- quality: Optional[str] | NotGiven = NOT_GIVEN,
- response_format: Optional[str] | NotGiven = NOT_GIVEN,
- size: Optional[str] | NotGiven = NOT_GIVEN,
- style: Optional[str] | NotGiven = NOT_GIVEN,
- user: str | NotGiven = NOT_GIVEN,
- extra_headers: Headers | None = None,
- disable_strict_validation: Optional[bool] | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> ImagesResponded:
- _cast_type = ImagesResponded
- if disable_strict_validation:
- _cast_type = object
- return self._post(
- "/images/generations",
- body={
- "prompt": prompt,
- "model": model,
- "n": n,
- "quality": quality,
- "response_format": response_format,
- "size": size,
- "style": style,
- "user": user,
- },
- options=make_user_request_input(
- extra_headers=extra_headers, timeout=timeout
- ),
- cast_type=_cast_type,
- enable_stream=False,
- )
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py
deleted file mode 100644
index 10b46ff8..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from .._client import ZhipuAI
-
-
-class BaseAPI:
- _client: ZhipuAI
-
- def __init__(self, client: ZhipuAI) -> None:
- self._client = client
- self._delete = client.delete
- self._get = client.get
- self._post = client.post
- self._put = client.put
- self._patch = client.patch
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py
deleted file mode 100644
index 40630556..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping, Sequence
-from os import PathLike
-from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union
-
-import pydantic
-from typing_extensions import override
-
-Query = Mapping[str, object]
-Body = object
-AnyMapping = Mapping[str, object]
-PrimitiveData = Union[str, int, float, bool, None]
-Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"]
-ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
-_T = TypeVar("_T")
-
-if TYPE_CHECKING:
- NoneType: type[None]
-else:
- NoneType = type(None)
-
-
-# Sentinel class used until PEP 0661 is accepted
-class NotGiven(pydantic.BaseModel):
- """
- A sentinel singleton class used to distinguish omitted keyword arguments
- from those passed in with the value None (which may have different behavior).
-
- For example:
-
- ```py
- def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
-
- get(timeout=1) # 1s timeout
- get(timeout=None) # No timeout
- get() # Default timeout behavior, which may not be statically known at the method definition.
- ```
- """
-
- def __bool__(self) -> Literal[False]:
- return False
-
- @override
- def __repr__(self) -> str:
- return "NOT_GIVEN"
-
-
-NotGivenOr = Union[_T, NotGiven]
-NOT_GIVEN = NotGiven()
-
-
-class Omit(pydantic.BaseModel):
- """In certain situations you need to be able to represent a case where a default value has
- to be explicitly removed and `None` is not an appropriate substitute, for example:
-
- ```py
- # as the default `Content-Type` header is `application/json` that will be sent
- client.post('/upload/files', files={'file': b'my raw file content'})
-
- # you can't explicitly override the header as it has to be dynamically generated
- # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
- client.post(..., headers={'Content-Type': 'multipart/form-data'})
-
- # instead you can remove the default `application/json` header by passing Omit
- client.post(..., headers={'Content-Type': Omit()})
- ```
- """
-
- def __bool__(self) -> Literal[False]:
- return False
-
-
-Headers = Mapping[str, Union[str, Omit]]
-
-ResponseT = TypeVar(
- "ResponseT",
- bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
-)
-
-# for user input files
-if TYPE_CHECKING:
- FileContent = Union[IO[bytes], bytes, PathLike[str]]
-else:
- FileContent = Union[IO[bytes], bytes, PathLike]
-
-FileTypes = Union[
- FileContent, # file content
- tuple[str, FileContent], # (filename, file)
- tuple[str, FileContent, str], # (filename, file , content_type)
- tuple[
- str, FileContent, str, Mapping[str, str]
- ], # (filename, file , content_type, headers)
-]
-
-RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
-
-# for httpx client supported files
-
-HttpxFileContent = Union[bytes, IO[bytes]]
-HttpxFileTypes = Union[
- FileContent, # file content
- tuple[str, HttpxFileContent], # (filename, file)
- tuple[str, HttpxFileContent, str], # (filename, file , content_type)
- tuple[
- str, HttpxFileContent, str, Mapping[str, str]
- ], # (filename, file , content_type, headers)
-]
-
-HttpxRequestFiles = Union[
- Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]
-]
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py
deleted file mode 100644
index 1800a3a3..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from __future__ import annotations
-
-import httpx
-
-__all__ = [
- "ZhipuAIError",
- "APIStatusError",
- "APIRequestFailedError",
- "APIAuthenticationError",
- "APIReachLimitError",
- "APIInternalError",
- "APIServerFlowExceedError",
- "APIResponseError",
- "APIResponseValidationError",
- "APITimeoutError",
-]
-
-
-class ZhipuAIError(Exception):
- def __init__(
- self,
- message: str,
- ) -> None:
- super().__init__(message)
-
-
-class APIStatusError(Exception):
- response: httpx.Response
- status_code: int
-
- def __init__(self, message: str, *, response: httpx.Response) -> None:
- super().__init__(message)
- self.response = response
- self.status_code = response.status_code
-
-
-class APIRequestFailedError(APIStatusError):
- ...
-
-
-class APIAuthenticationError(APIStatusError):
- ...
-
-
-class APIReachLimitError(APIStatusError):
- ...
-
-
-class APIInternalError(APIStatusError):
- ...
-
-
-class APIServerFlowExceedError(APIStatusError):
- ...
-
-
-class APIResponseError(Exception):
- message: str
- request: httpx.Request
- json_data: object
-
- def __init__(self, message: str, request: httpx.Request, json_data: object):
- self.message = message
- self.request = request
- self.json_data = json_data
- super().__init__(message)
-
-
-class APIResponseValidationError(APIResponseError):
- status_code: int
- response: httpx.Response
-
- def __init__(
- self,
- response: httpx.Response,
- json_data: object | None,
- *,
- message: str | None = None,
- ) -> None:
- super().__init__(
- message=message or "Data returned by API invalid for expected schema.",
- request=response.request,
- json_data=json_data,
- )
- self.response = response
- self.status_code = response.status_code
-
-
-class APITimeoutError(Exception):
- request: httpx.Request
-
- def __init__(self, request: httpx.Request):
- self.request = request
- super().__init__("Request Timeout")
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
deleted file mode 100644
index e7fa1ad2..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from __future__ import annotations
-
-import io
-import os
-from collections.abc import Mapping, Sequence
-from pathlib import Path
-
-from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles
-
-
-def is_file_content(obj: object) -> bool:
- return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike)
-
-
-def _transform_file(file: FileTypes) -> HttpxFileTypes:
- if is_file_content(file):
- if isinstance(file, os.PathLike):
- path = Path(file)
- return path.name, path.read_bytes()
- else:
- return file
- if isinstance(file, tuple):
- if isinstance(file[1], os.PathLike):
- return (file[0], Path(file[1]).read_bytes(), *file[2:])
- else:
- return (file[0], file[1], *file[2:])
- else:
- raise TypeError(
- f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type"
- )
-
-
-def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
- if files is None:
- return None
-
- if isinstance(files, Mapping):
- files = {key: _transform_file(file) for key, file in files.items()}
- elif isinstance(files, Sequence):
- files = [(key, _transform_file(file)) for key, file in files]
- else:
- raise TypeError(
- f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence"
- )
- return files
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
deleted file mode 100644
index 9e968d52..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
+++ /dev/null
@@ -1,401 +0,0 @@
-from __future__ import annotations
-
-import inspect
-from collections.abc import Mapping
-from typing import Any, Union, cast
-
-import httpx
-import pydantic
-from httpx import URL, Timeout
-
-from . import _errors
-from ._base_type import (
- NOT_GIVEN,
- Body,
- Data,
- Headers,
- NotGiven,
- Query,
- RequestFiles,
- ResponseT,
-)
-from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
-from ._files import make_httpx_files
-from ._request_opt import ClientRequestParam, UserRequestInput
-from ._response import HttpResponse
-from ._sse_client import StreamResponse
-from ._utils import flatten
-
-headers = {
- "Accept": "application/json",
- "Content-Type": "application/json; charset=UTF-8",
-}
-
-
-def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
- merged = {**map1, **map2}
- return {key: val for key, val in merged.items() if val is not None}
-
-
-from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
-
-ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
-ZHIPUAI_DEFAULT_MAX_RETRIES = 3
-ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=5, max_keepalive_connections=5)
-
-
-class HttpClient:
- _client: httpx.Client
- _version: str
- _base_url: URL
-
- timeout: Union[float, Timeout, None]
- _limits: httpx.Limits
- _has_custom_http_client: bool
- _default_stream_cls: type[StreamResponse[Any]] | None = None
-
- def __init__(
- self,
- *,
- version: str,
- base_url: URL,
- timeout: Union[float, Timeout, None],
- custom_httpx_client: httpx.Client | None = None,
- custom_headers: Mapping[str, str] | None = None,
- ) -> None:
- if timeout is None or isinstance(timeout, NotGiven):
- if (
- custom_httpx_client
- and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT
- ):
- timeout = custom_httpx_client.timeout
- else:
- timeout = ZHIPUAI_DEFAULT_TIMEOUT
- self.timeout = cast(Timeout, timeout)
- self._has_custom_http_client = bool(custom_httpx_client)
- self._client = custom_httpx_client or httpx.Client(
- base_url=base_url,
- timeout=self.timeout,
- limits=ZHIPUAI_DEFAULT_LIMITS,
- )
- self._version = version
- url = URL(url=base_url)
- if not url.raw_path.endswith(b"/"):
- url = url.copy_with(raw_path=url.raw_path + b"/")
- self._base_url = url
- self._custom_headers = custom_headers or {}
-
- def _prepare_url(self, url: str) -> URL:
- sub_url = URL(url)
- if sub_url.is_relative_url:
- request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
- return self._base_url.copy_with(raw_path=request_raw_url)
-
- return sub_url
-
- @property
- def _default_headers(self):
- return {
- "Accept": "application/json",
- "Content-Type": "application/json; charset=UTF-8",
- "ZhipuAI-SDK-Ver": self._version,
- "source_type": "zhipu-sdk-python",
- "x-request-sdk": "zhipu-sdk-python",
- **self._auth_headers,
- **self._custom_headers,
- }
-
- @property
- def _auth_headers(self):
- return {}
-
- def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
- custom_headers = request_param.headers or {}
- headers_dict = _merge_map(self._default_headers, custom_headers)
-
- httpx_headers = httpx.Headers(headers_dict)
-
- return httpx_headers
-
- def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request:
- kwargs: dict[str, Any] = {}
- json_data = request_param.json_data
- headers = self._prepare_headers(request_param)
- url = self._prepare_url(request_param.url)
- json_data = request_param.json_data
- if headers.get("Content-Type") == "multipart/form-data":
- headers.pop("Content-Type")
-
- if json_data:
- kwargs["data"] = self._make_multipartform(json_data)
-
- return self._client.build_request(
- headers=headers,
- timeout=self.timeout
- if isinstance(request_param.timeout, NotGiven)
- else request_param.timeout,
- method=request_param.method,
- url=url,
- json=json_data,
- files=request_param.files,
- params=request_param.params,
- **kwargs,
- )
-
- def _object_to_formfata(
- self, key: str, value: Data | Mapping[object, object]
- ) -> list[tuple[str, str]]:
- items = []
-
- if isinstance(value, Mapping):
- for k, v in value.items():
- items.extend(self._object_to_formfata(f"{key}[{k}]", v))
- return items
- if isinstance(value, list | tuple):
- for v in value:
- items.extend(self._object_to_formfata(key + "[]", v))
- return items
-
- def _primitive_value_to_str(val) -> str:
- # copied from httpx
- if val is True:
- return "true"
- elif val is False:
- return "false"
- elif val is None:
- return ""
- return str(val)
-
- str_data = _primitive_value_to_str(value)
-
- if not str_data:
- return []
- return [(key, str_data)]
-
- def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
- items = flatten([self._object_to_formfata(k, v) for k, v in data.items()])
-
- serialized: dict[str, object] = {}
- for key, value in items:
- if key in serialized:
- raise ValueError(f"存在重复的键: {key};")
- serialized[key] = value
- return serialized
-
- def _parse_response(
- self,
- *,
- cast_type: type[ResponseT],
- response: httpx.Response,
- enable_stream: bool,
- request_param: ClientRequestParam,
- stream_cls: type[StreamResponse[Any]] | None = None,
- ) -> HttpResponse:
- http_response = HttpResponse(
- raw_response=response,
- cast_type=cast_type,
- client=self,
- enable_stream=enable_stream,
- stream_cls=stream_cls,
- )
- return http_response.parse()
-
- def _process_response_data(
- self,
- *,
- data: object,
- cast_type: type[ResponseT],
- response: httpx.Response,
- ) -> ResponseT:
- if data is None:
- return cast(ResponseT, None)
-
- try:
- if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
- return cast(ResponseT, cast_type.validate(data))
-
- return cast(
- ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)
- )
- except pydantic.ValidationError as err:
- raise APIResponseValidationError(response=response, json_data=data) from err
-
- def is_closed(self) -> bool:
- return self._client.is_closed
-
- def close(self):
- self._client.close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.close()
-
- def request(
- self,
- *,
- cast_type: type[ResponseT],
- params: ClientRequestParam,
- enable_stream: bool = False,
- stream_cls: type[StreamResponse[Any]] | None = None,
- ) -> ResponseT | StreamResponse:
- request = self._prepare_request(params)
-
- try:
- response = self._client.send(
- request,
- stream=enable_stream,
- )
- response.raise_for_status()
- except httpx.TimeoutException as err:
- raise APITimeoutError(request=request) from err
- except httpx.HTTPStatusError as err:
- err.response.read()
- # raise err
- raise self._make_status_error(err.response) from None
-
- except Exception as err:
- raise err
-
- return self._parse_response(
- cast_type=cast_type,
- request_param=params,
- response=response,
- enable_stream=enable_stream,
- stream_cls=stream_cls,
- )
-
- def get(
- self,
- path: str,
- *,
- cast_type: type[ResponseT],
- options: UserRequestInput = {},
- enable_stream: bool = False,
- ) -> ResponseT | StreamResponse:
- opts = ClientRequestParam.construct(method="get", url=path, **options)
- return self.request(
- cast_type=cast_type, params=opts, enable_stream=enable_stream
- )
-
- def post(
- self,
- path: str,
- *,
- body: Body | None = None,
- cast_type: type[ResponseT],
- options: UserRequestInput = {},
- files: RequestFiles | None = None,
- enable_stream: bool = False,
- stream_cls: type[StreamResponse[Any]] | None = None,
- ) -> ResponseT | StreamResponse:
- opts = ClientRequestParam.construct(
- method="post",
- json_data=body,
- files=make_httpx_files(files),
- url=path,
- **options,
- )
-
- return self.request(
- cast_type=cast_type,
- params=opts,
- enable_stream=enable_stream,
- stream_cls=stream_cls,
- )
-
- def patch(
- self,
- path: str,
- *,
- body: Body | None = None,
- cast_type: type[ResponseT],
- options: UserRequestInput = {},
- ) -> ResponseT:
- opts = ClientRequestParam.construct(
- method="patch", url=path, json_data=body, **options
- )
-
- return self.request(
- cast_type=cast_type,
- params=opts,
- )
-
- def put(
- self,
- path: str,
- *,
- body: Body | None = None,
- cast_type: type[ResponseT],
- options: UserRequestInput = {},
- files: RequestFiles | None = None,
- ) -> ResponseT | StreamResponse:
- opts = ClientRequestParam.construct(
- method="put",
- url=path,
- json_data=body,
- files=make_httpx_files(files),
- **options,
- )
-
- return self.request(
- cast_type=cast_type,
- params=opts,
- )
-
- def delete(
- self,
- path: str,
- *,
- body: Body | None = None,
- cast_type: type[ResponseT],
- options: UserRequestInput = {},
- ) -> ResponseT | StreamResponse:
- opts = ClientRequestParam.construct(
- method="delete", url=path, json_data=body, **options
- )
-
- return self.request(
- cast_type=cast_type,
- params=opts,
- )
-
- def _make_status_error(self, response) -> APIStatusError:
- response_text = response.text.strip()
- status_code = response.status_code
- error_msg = f"Error code: {status_code}, with error text {response_text}"
-
- if status_code == 400:
- return _errors.APIRequestFailedError(message=error_msg, response=response)
- elif status_code == 401:
- return _errors.APIAuthenticationError(message=error_msg, response=response)
- elif status_code == 429:
- return _errors.APIReachLimitError(message=error_msg, response=response)
- elif status_code == 500:
- return _errors.APIInternalError(message=error_msg, response=response)
- elif status_code == 503:
- return _errors.APIServerFlowExceedError(
- message=error_msg, response=response
- )
- return APIStatusError(message=error_msg, response=response)
-
-
-def make_user_request_input(
- max_retries: int | None = None,
- timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
- extra_headers: Headers = None,
- query: Query | None = None,
-) -> UserRequestInput:
- options: UserRequestInput = {}
-
- if extra_headers is not None:
- options["headers"] = extra_headers
- if max_retries is not None:
- options["max_retries"] = max_retries
- if not isinstance(timeout, NotGiven):
- options["timeout"] = timeout
- if query is not None:
- options["params"] = query
-
- return options
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py
deleted file mode 100644
index b0a91d04..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import time
-
-import cachetools.func
-import jwt
-
-API_TOKEN_TTL_SECONDS = 3 * 60
-
-CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
-
-
-@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
-def generate_token(apikey: str):
- try:
- api_key, secret = apikey.split(".")
- except Exception as e:
- raise Exception("invalid api_key", e)
-
- payload = {
- "api_key": api_key,
- "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
- "timestamp": int(round(time.time() * 1000)),
- }
- ret = jwt.encode(
- payload,
- secret,
- algorithm="HS256",
- headers={"alg": "HS256", "sign_type": "SIGN"},
- )
- return ret
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py
deleted file mode 100644
index 7bd5b3e4..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, ClassVar, Union
-
-from httpx import Timeout
-from pydantic import ConfigDict
-from typing_extensions import TypedDict, Unpack
-
-from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query
-from ._utils import remove_notgiven_indict
-
-
-class UserRequestInput(TypedDict, total=False):
- max_retries: int
- timeout: float | Timeout | None
- headers: Headers
- params: Query | None
-
-
-class ClientRequestParam:
- method: str
- url: str
- max_retries: Union[int, NotGiven] = NotGiven()
- timeout: Union[float, NotGiven] = NotGiven()
- headers: Union[Headers, NotGiven] = NotGiven()
- json_data: Union[Body, None] = None
- files: Union[HttpxRequestFiles, None] = None
- params: Query = {}
- model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
-
- def get_max_retries(self, max_retries) -> int:
- if isinstance(self.max_retries, NotGiven):
- return max_retries
- return self.max_retries
-
- @classmethod
- def construct( # type: ignore
- cls,
- _fields_set: set[str] | None = None,
- **values: Unpack[UserRequestInput],
- ) -> ClientRequestParam:
- kwargs: dict[str, Any] = {
- key: remove_notgiven_indict(value) for key, value in values.items()
- }
- client = cls()
- client.__dict__.update(kwargs)
-
- return client
-
- model_construct = construct
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py
deleted file mode 100644
index 7addfd8c..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from __future__ import annotations
-
-import datetime
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin
-
-import httpx
-import pydantic
-from typing_extensions import ParamSpec
-
-from ._base_type import NoneType
-from ._sse_client import StreamResponse
-
-if TYPE_CHECKING:
- from ._http_client import HttpClient
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-class HttpResponse(Generic[R]):
- _cast_type: type[R]
- _client: HttpClient
- _parsed: R | None
- _enable_stream: bool
- _stream_cls: type[StreamResponse[Any]]
- http_response: httpx.Response
-
- def __init__(
- self,
- *,
- raw_response: httpx.Response,
- cast_type: type[R],
- client: HttpClient,
- enable_stream: bool = False,
- stream_cls: type[StreamResponse[Any]] | None = None,
- ) -> None:
- self._cast_type = cast_type
- self._client = client
- self._parsed = None
- self._stream_cls = stream_cls
- self._enable_stream = enable_stream
- self.http_response = raw_response
-
- def parse(self) -> R:
- self._parsed = self._parse()
- return self._parsed
-
- def _parse(self) -> R:
- if self._enable_stream:
- self._parsed = cast(
- R,
- self._stream_cls(
- cast_type=cast(type, get_args(self._stream_cls)[0]),
- response=self.http_response,
- client=self._client,
- ),
- )
- return self._parsed
- cast_type = self._cast_type
- if cast_type is NoneType:
- return cast(R, None)
- http_response = self.http_response
- if cast_type == str:
- return cast(R, http_response.text)
-
- content_type, *_ = http_response.headers.get(
- "content-type", "application/json"
- ).split(";")
- origin = get_origin(cast_type) or cast_type
- if content_type != "application/json":
- if issubclass(origin, pydantic.BaseModel):
- data = http_response.json()
- return self._client._process_response_data(
- data=data,
- cast_type=cast_type, # type: ignore
- response=http_response,
- )
-
- return http_response.text
-
- data = http_response.json()
-
- return self._client._process_response_data(
- data=data,
- cast_type=cast_type, # type: ignore
- response=http_response,
- )
-
- @property
- def headers(self) -> httpx.Headers:
- return self.http_response.headers
-
- @property
- def http_request(self) -> httpx.Request:
- return self.http_response.request
-
- @property
- def status_code(self) -> int:
- return self.http_response.status_code
-
- @property
- def url(self) -> httpx.URL:
- return self.http_response.url
-
- @property
- def method(self) -> str:
- return self.http_request.method
-
- @property
- def content(self) -> bytes:
- return self.http_response.content
-
- @property
- def text(self) -> str:
- return self.http_response.text
-
- @property
- def http_version(self) -> str:
- return self.http_response.http_version
-
- @property
- def elapsed(self) -> datetime.timedelta:
- return self.http_response.elapsed
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py
deleted file mode 100644
index ce3b6df6..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py
+++ /dev/null
@@ -1,155 +0,0 @@
-from __future__ import annotations
-
-import json
-from collections.abc import Iterator, Mapping
-from typing import TYPE_CHECKING, Generic
-
-import httpx
-
-from ._base_type import ResponseT
-from ._errors import APIResponseError
-
-_FIELD_SEPARATOR = ":"
-
-if TYPE_CHECKING:
- from ._http_client import HttpClient
-
-
-class StreamResponse(Generic[ResponseT]):
- response: httpx.Response
- _cast_type: type[ResponseT]
-
- def __init__(
- self,
- *,
- cast_type: type[ResponseT],
- response: httpx.Response,
- client: HttpClient,
- ) -> None:
- self.response = response
- self._cast_type = cast_type
- self._data_process_func = client._process_response_data
- self._stream_chunks = self.__stream__()
-
- def __next__(self) -> ResponseT:
- return self._stream_chunks.__next__()
-
- def __iter__(self) -> Iterator[ResponseT]:
- yield from self._stream_chunks
-
- def __stream__(self) -> Iterator[ResponseT]:
- sse_line_parser = SSELineParser()
- iterator = sse_line_parser.iter_lines(self.response.iter_lines())
-
- for sse in iterator:
- if sse.data.startswith("[DONE]"):
- break
-
- if sse.event is None:
- data = sse.json_data()
- if isinstance(data, Mapping) and data.get("error"):
- raise APIResponseError(
- message="An error occurred during streaming",
- request=self.response.request,
- json_data=data["error"],
- )
-
- yield self._data_process_func(
- data=data, cast_type=self._cast_type, response=self.response
- )
- for sse in iterator:
- pass
-
-
-class Event:
- def __init__(
- self,
- event: str | None = None,
- data: str | None = None,
- id: str | None = None,
- retry: int | None = None,
- ):
- self._event = event
- self._data = data
- self._id = id
- self._retry = retry
-
- def __repr__(self):
- data_len = len(self._data) if self._data else 0
- return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"
-
- @property
- def event(self):
- return self._event
-
- @property
- def data(self):
- return self._data
-
- def json_data(self):
- return json.loads(self._data)
-
- @property
- def id(self):
- return self._id
-
- @property
- def retry(self):
- return self._retry
-
-
-class SSELineParser:
- _data: list[str]
- _event: str | None
- _retry: int | None
- _id: str | None
-
- def __init__(self):
- self._event = None
- self._data = []
- self._id = None
- self._retry = None
-
- def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
- for line in lines:
- line = line.rstrip("\n")
- if not line:
- if (
- self._event is None
- and not self._data
- and self._id is None
- and self._retry is None
- ):
- continue
- sse_event = Event(
- event=self._event,
- data="\n".join(self._data),
- id=self._id,
- retry=self._retry,
- )
- self._event = None
- self._data = []
- self._id = None
- self._retry = None
-
- yield sse_event
- self.decode_line(line)
-
- def decode_line(self, line: str):
- if line.startswith(":") or not line:
- return
-
- field, _p, value = line.partition(":")
-
- if value.startswith(" "):
- value = value[1:]
- if field == "data":
- self._data.append(value)
- elif field == "event":
- self._event = value
- elif field == "retry":
- try:
- self._retry = int(value)
- except (TypeError, ValueError):
- pass
- return
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py
deleted file mode 100644
index 6b610567..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Iterable, Mapping
-from typing import TypeVar
-
-from ._base_type import NotGiven
-
-
-def remove_notgiven_indict(obj):
- if obj is None or (not isinstance(obj, Mapping)):
- return obj
- return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
-
-
-_T = TypeVar("_T")
-
-
-def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
- return [item for sublist in t for item in sublist]
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py
deleted file mode 100644
index a0645b09..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-from .chat_completion import CompletionChoice, CompletionUsage
-
-__all__ = ["AsyncTaskStatus"]
-
-
-class AsyncTaskStatus(BaseModel):
- id: Optional[str] = None
- request_id: Optional[str] = None
- model: Optional[str] = None
- task_status: Optional[str] = None
-
-
-class AsyncCompletion(BaseModel):
- id: Optional[str] = None
- request_id: Optional[str] = None
- model: Optional[str] = None
- task_status: str
- choices: list[CompletionChoice]
- usage: CompletionUsage
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py
deleted file mode 100644
index 4b3a929a..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-__all__ = ["Completion", "CompletionUsage"]
-
-
-class Function(BaseModel):
- arguments: str
- name: str
-
-
-class CompletionMessageToolCall(BaseModel):
- id: str
- function: Function
- type: str
-
-
-class CompletionMessage(BaseModel):
- content: Optional[str] = None
- role: str
- tool_calls: Optional[list[CompletionMessageToolCall]] = None
-
-
-class CompletionUsage(BaseModel):
- prompt_tokens: int
- completion_tokens: int
- total_tokens: int
-
-
-class CompletionChoice(BaseModel):
- index: int
- finish_reason: str
- message: CompletionMessage
-
-
-class Completion(BaseModel):
- model: Optional[str] = None
- created: Optional[int] = None
- choices: list[CompletionChoice]
- request_id: Optional[str] = None
- id: Optional[str] = None
- usage: CompletionUsage
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py
deleted file mode 100644
index c2506997..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-__all__ = [
- "ChatCompletionChunk",
- "Choice",
- "ChoiceDelta",
- "ChoiceDeltaFunctionCall",
- "ChoiceDeltaToolCall",
- "ChoiceDeltaToolCallFunction",
-]
-
-
-class ChoiceDeltaFunctionCall(BaseModel):
- arguments: Optional[str] = None
- name: Optional[str] = None
-
-
-class ChoiceDeltaToolCallFunction(BaseModel):
- arguments: Optional[str] = None
- name: Optional[str] = None
-
-
-class ChoiceDeltaToolCall(BaseModel):
- index: int
- id: Optional[str] = None
- function: Optional[ChoiceDeltaToolCallFunction] = None
- type: Optional[str] = None
-
-
-class ChoiceDelta(BaseModel):
- content: Optional[str] = None
- role: Optional[str] = None
- tool_calls: Optional[list[ChoiceDeltaToolCall]] = None
-
-
-class Choice(BaseModel):
- delta: ChoiceDelta
- finish_reason: Optional[str] = None
- index: int
-
-
-class CompletionUsage(BaseModel):
- prompt_tokens: int
- completion_tokens: int
- total_tokens: int
-
-
-class ChatCompletionChunk(BaseModel):
- id: Optional[str] = None
- choices: list[Choice]
- created: Optional[int] = None
- model: Optional[str] = None
- usage: Optional[CompletionUsage] = None
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py
deleted file mode 100644
index 6ee4dc47..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from typing import Optional
-
-from typing_extensions import TypedDict
-
-
-class Reference(TypedDict, total=False):
- enable: Optional[bool]
- search_query: Optional[str]
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py
deleted file mode 100644
index e01f2c81..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-from pydantic import BaseModel
-
-from .chat.chat_completion import CompletionUsage
-
-__all__ = ["Embedding", "EmbeddingsResponded"]
-
-
-class Embedding(BaseModel):
- object: str
- index: Optional[int] = None
- embedding: list[float]
-
-
-class EmbeddingsResponded(BaseModel):
- object: str
- data: list[Embedding]
- model: str
- usage: CompletionUsage
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py
deleted file mode 100644
index 75f76fe9..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-__all__ = ["FileObject"]
-
-
-class FileObject(BaseModel):
- id: Optional[str] = None
- bytes: Optional[int] = None
- created_at: Optional[int] = None
- filename: Optional[str] = None
- object: Optional[str] = None
- purpose: Optional[str] = None
- status: Optional[str] = None
- status_details: Optional[str] = None
-
-
-class ListOfFileObject(BaseModel):
- object: Optional[str] = None
- data: list[FileObject]
- has_more: Optional[bool] = None
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
deleted file mode 100644
index af099189..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from __future__ import annotations
-
-from .fine_tuning_job import FineTuningJob as FineTuningJob
-from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
-from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py
deleted file mode 100644
index 1d393028..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from typing import Optional, Union
-
-from pydantic import BaseModel
-
-__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"]
-
-
-class Error(BaseModel):
- code: str
- message: str
- param: Optional[str] = None
-
-
-class Hyperparameters(BaseModel):
- n_epochs: Union[str, int, None] = None
-
-
-class FineTuningJob(BaseModel):
- id: Optional[str] = None
-
- request_id: Optional[str] = None
-
- created_at: Optional[int] = None
-
- error: Optional[Error] = None
-
- fine_tuned_model: Optional[str] = None
-
- finished_at: Optional[int] = None
-
- hyperparameters: Optional[Hyperparameters] = None
-
- model: Optional[str] = None
-
- object: Optional[str] = None
-
- result_files: list[str]
-
- status: str
-
- trained_tokens: Optional[int] = None
-
- training_file: str
-
- validation_file: Optional[str] = None
-
-
-class ListOfFineTuningJob(BaseModel):
- object: Optional[str] = None
- data: list[FineTuningJob]
- has_more: Optional[bool] = None
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py
deleted file mode 100644
index e26b4485..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Optional, Union
-
-from pydantic import BaseModel
-
-__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]
-
-
-class Metric(BaseModel):
- epoch: Optional[Union[str, int, float]] = None
- current_steps: Optional[int] = None
- total_steps: Optional[int] = None
- elapsed_time: Optional[str] = None
- remaining_time: Optional[str] = None
- trained_tokens: Optional[int] = None
- loss: Optional[Union[str, int, float]] = None
- eval_loss: Optional[Union[str, int, float]] = None
- acc: Optional[Union[str, int, float]] = None
- eval_acc: Optional[Union[str, int, float]] = None
- learning_rate: Optional[Union[str, int, float]] = None
-
-
-class JobEvent(BaseModel):
- object: Optional[str] = None
- id: Optional[str] = None
- type: Optional[str] = None
- created_at: Optional[int] = None
- level: Optional[str] = None
- message: Optional[str] = None
- data: Optional[Metric] = None
-
-
-class FineTuningJobEvent(BaseModel):
- object: Optional[str] = None
- data: list[JobEvent]
- has_more: Optional[bool] = None
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py
deleted file mode 100644
index e1ebc352..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from __future__ import annotations
-
-from typing import Literal, Union
-
-from typing_extensions import TypedDict
-
-__all__ = ["Hyperparameters"]
-
-
-class Hyperparameters(TypedDict, total=False):
- batch_size: Union[Literal["auto"], int]
-
- learning_rate_multiplier: Union[Literal["auto"], float]
-
- n_epochs: Union[Literal["auto"], int]
diff --git a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py
deleted file mode 100644
index b352ce09..00000000
--- a/model-providers/model_providers/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-from pydantic import BaseModel
-
-__all__ = ["GeneratedImage", "ImagesResponded"]
-
-
-class GeneratedImage(BaseModel):
- b64_json: Optional[str] = None
- url: Optional[str] = None
- revised_prompt: Optional[str] = None
-
-
-class ImagesResponded(BaseModel):
- created: int
- data: list[GeneratedImage]
diff --git a/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py b/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py
index 8d56fb65..4b2ee2ee 100644
--- a/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py
+++ b/model-providers/model_providers/core/model_runtime/schema_validators/common_validator.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import List, Optional
from model_providers.core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
@@ -8,7 +8,7 @@ from model_providers.core.model_runtime.entities.provider_entities import (
class CommonValidator:
def _validate_and_filter_credential_form_schemas(
- self, credential_form_schemas: list[CredentialFormSchema], credentials: dict
+ self, credential_form_schemas: List[CredentialFormSchema], credentials: dict
) -> dict:
need_validate_credential_form_schema_map = {}
for credential_form_schema in credential_form_schemas:
diff --git a/model-providers/model_providers/core/model_runtime/utils/encoders.py b/model-providers/model_providers/core/model_runtime/utils/encoders.py
index 7c98c5e0..cea96079 100644
--- a/model-providers/model_providers/core/model_runtime/utils/encoders.py
+++ b/model-providers/model_providers/core/model_runtime/utils/encoders.py
@@ -54,7 +54,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
return float(dec_value)
-ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
+ENCODERS_BY_TYPE: Dict[type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
@@ -85,9 +85,9 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
def generate_encoders_by_class_tuples(
- type_encoder_map: dict[Any, Callable[[Any], Any]],
-) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
- encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(
+ type_encoder_map: Dict[Any, Callable[[Any], Any]],
+) -> Dict[Callable[[Any], Any], tuple[Any, ...]]:
+ encoders_by_class_tuples: Dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
@@ -117,7 +117,7 @@ def jsonable_encoder(
return encoder_instance(obj)
if isinstance(obj, BaseModel):
# TODO: remove when deprecating Pydantic v1
- encoders: dict[Any, Any] = {}
+ encoders: Dict[Any, Any] = {}
if not PYDANTIC_V2:
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
@@ -217,7 +217,7 @@ def jsonable_encoder(
try:
data = dict(obj)
except Exception as e:
- errors: list[Exception] = []
+ errors: List[Exception] = []
errors.append(e)
try:
data = vars(obj)
diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py
index e8703e45..ad511337 100644
--- a/model-providers/model_providers/core/provider_manager.py
+++ b/model-providers/model_providers/core/provider_manager.py
@@ -1,9 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
-from typing import Optional, Union
-
-from sqlalchemy.exc import IntegrityError
+from typing import List, Optional, Union
from model_providers.core.entities.model_entities import (
DefaultModelEntity,
@@ -201,7 +199,7 @@ class ProviderManager:
self,
provider_entity: ProviderEntity,
provider_credentials: dict,
- provider_model_records: list[dict],
+ provider_model_records: List[dict],
) -> CustomConfiguration:
"""
Convert to custom configuration.
@@ -266,8 +264,8 @@ class ProviderManager:
)
def _extract_variables(
- self, credential_form_schemas: list[CredentialFormSchema]
- ) -> list[str]:
+ self, credential_form_schemas: List[CredentialFormSchema]
+ ) -> List[str]:
"""
Extract input form variables.
diff --git a/model-providers/model_providers/core/utils/position_helper.py b/model-providers/model_providers/core/utils/position_helper.py
index 55fd754c..f6c14d81 100644
--- a/model-providers/model_providers/core/utils/position_helper.py
+++ b/model-providers/model_providers/core/utils/position_helper.py
@@ -1,8 +1,6 @@
import logging
import os
-from collections import OrderedDict
-from collections.abc import Callable
-from typing import Any, AnyStr
+from typing import Any, AnyStr, Callable, Dict, List, OrderedDict
import yaml
@@ -10,7 +8,7 @@ import yaml
def get_position_map(
folder_path: AnyStr,
file_name: str = "_position.yaml",
-) -> dict[str, int]:
+) -> Dict[str, int]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
@@ -37,10 +35,10 @@ def get_position_map(
def sort_by_position_map(
- position_map: dict[str, int],
- data: list[Any],
+ position_map: Dict[str, int],
+ data: List[Any],
name_func: Callable[[Any], str],
-) -> list[Any]:
+) -> List[Any]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@@ -56,8 +54,8 @@ def sort_by_position_map(
def sort_to_dict_by_position_map(
- position_map: dict[str, int],
- data: list[Any],
+ position_map: Dict[str, int],
+ data: List[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
"""
diff --git a/model-providers/pyproject.toml b/model-providers/pyproject.toml
index 859006de..ba4d923d 100644
--- a/model-providers/pyproject.toml
+++ b/model-providers/pyproject.toml
@@ -34,8 +34,8 @@ pytest-asyncio = "^0.21.1"
grandalf = "^0.8"
pytest-profiling = "^1.7.0"
responses = "^0.25.0"
-
-
+langchain = "0.1.5"
+langchain-openai = "0.0.5"
[tool.poetry.group.lint]
optional = true
diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/conftest.py
similarity index 77%
rename from model-providers/tests/server_unit_test/conftest.py
rename to model-providers/tests/conftest.py
index eea02a65..a4508b81 100644
--- a/model-providers/tests/server_unit_test/conftest.py
+++ b/model-providers/tests/conftest.py
@@ -6,6 +6,7 @@ from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
+from model_providers import BootstrapWebBuilder
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
@@ -102,3 +103,41 @@ def logging_conf() -> dict:
122,
111,
)
+
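+# Per-directory configuration: each provider test folder ships its own
+# model_providers.yaml, resolved relative to the requesting test module.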
+@pytest.fixture
+def providers_file(request) -> str:
+    import os
+    from pathlib import Path
+
+    # Directory of the test module that requested this fixture
+    test_file_path = Path(str(request.fspath)).parent
+    print("test_file_path:", test_file_path)
+    return os.path.join(test_file_path, "model_providers.yaml")
+
+
+@pytest.fixture
+@pytest.mark.requires("fastapi")
+def init_server(logging_conf: dict, providers_file: str):
+    try:
+        # Boot a local model-providers web server for the duration of the test
+        boot = (
+            BootstrapWebBuilder()
+            .model_providers_cfg_path(
+                model_providers_cfg_path=providers_file
+            )
+            .host(host="127.0.0.1")
+            .port(port=20000)
+            .build()
+        )
+        boot.set_app_event(started_event=None)
+        boot.logging_conf(logging_conf=logging_conf)
+        boot.run()
+
+        try:
+            # Hand the server's base URL to the test, then tear it down afterwards
+            yield "http://127.0.0.1:20000"
+        finally:
+            boot.destroy()
+
+    except SystemExit:
+        raise
diff --git a/model-providers/tests/openai_providers_test/model_providers.yaml b/model-providers/tests/openai_providers_test/model_providers.yaml
new file mode 100644
index 00000000..b98d2924
--- /dev/null
+++ b/model-providers/tests/openai_providers_test/model_providers.yaml
@@ -0,0 +1,5 @@
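+# OpenAI provider credentials used by the tests in this directory.
+# 'sk-' is a placeholder; set a real API key before running against the live API.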
+openai:
+ provider_credential:
+ openai_api_key: 'sk-'
+ openai_organization: ''
+ openai_api_base: ''
diff --git a/model-providers/tests/openai_providers_test/test_openai_service.py b/model-providers/tests/openai_providers_test/test_openai_service.py
new file mode 100644
index 00000000..958fa108
--- /dev/null
+++ b/model-providers/tests/openai_providers_test/test_openai_service.py
@@ -0,0 +1,36 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
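+# These tests exercise the OpenAI-compatible endpoint exposed by the init_server
+# fixture; "YOUR_API_KEY" is a placeholder, since the gateway is expected to supply
+# the real provider credential from model_providers.yaml.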
+@pytest.mark.requires("openai")
+def test_llm(init_server: str):
+ llm = ChatOpenAI(openai_api_key="YOUR_API_KEY", openai_api_base=f"{init_server}/openai/v1")
+ template = """Question: {question}
+
+ Answer: Let's think step by step."""
+
+ prompt = PromptTemplate.from_template(template)
+
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ responses = llm_chain.run("你好")
+ logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("openai")
+def test_embedding(init_server: str):
+
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-3-large",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/openai/v1",
+    )
+
+ text = "你好"
+
+ query_result = embeddings.embed_query(text)
+
+ logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
\ No newline at end of file
diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py
deleted file mode 100644
index 96210b89..00000000
--- a/model-providers/tests/server_unit_test/test_init_server.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import asyncio
-import logging
-
-import pytest
-
-from model_providers import BootstrapWebBuilder
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.requires("fastapi")
-def test_init_server(logging_conf: dict) -> None:
- try:
- boot = (
- BootstrapWebBuilder()
- .model_providers_cfg_path(
- model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers"
- "/model_providers.yaml"
- )
- .host(host="127.0.0.1")
- .port(port=20000)
- .build()
- )
- boot.set_app_event(started_event=None)
- boot.serve(logging_conf=logging_conf)
-
- async def pool_join_thread():
- await boot.join()
-
- asyncio.run(pool_join_thread())
- except SystemExit:
- logger.info("SystemExit raised, exiting")
- raise
diff --git a/model-providers/tests/zhipuai_providers_test/model_providers.yaml b/model-providers/tests/zhipuai_providers_test/model_providers.yaml
new file mode 100644
index 00000000..ec13ffc4
--- /dev/null
+++ b/model-providers/tests/zhipuai_providers_test/model_providers.yaml
@@ -0,0 +1,4 @@
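+# Zhipu AI provider credentials for the tests in this directory; replace api_key
+# with your own key. api_base is optional and can point at a staging endpoint.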
+zhipuai:
+ provider_credential:
+ api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
+# api_base: 'https://test.bigmodel.cn/stage-api/paas/v4'
\ No newline at end of file
diff --git a/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py
new file mode 100644
index 00000000..c110b71c
--- /dev/null
+++ b/model-providers/tests/zhipuai_providers_test/test_zhipuai_service.py
@@ -0,0 +1,39 @@
+from langchain.chains import LLMChain
+from langchain_core.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+import pytest
+import logging
+
+logger = logging.getLogger(__name__)
+
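+# These tests route through the init_server fixture's OpenAI-compatible gateway for
+# the zhipuai provider; "YOUR_API_KEY" is a placeholder, since the gateway is
+# expected to supply the real credential from model_providers.yaml.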
+@pytest.mark.requires("zhipuai")
+def test_llm(init_server: str):
+    llm = ChatOpenAI(
+        model_name="glm-4",
+        openai_api_key="YOUR_API_KEY",
+        openai_api_base=f"{init_server}/zhipuai/v1",
+    )
+ template = """Question: {question}
+
+ Answer: Let's think step by step."""
+
+ prompt = PromptTemplate.from_template(template)
+
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ responses = llm_chain.run("你好")
+ logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")
+
+
+@pytest.mark.requires("zhipuai")
+def test_embedding(init_server: str):
+
+ embeddings = OpenAIEmbeddings(model="text_embedding",
+ openai_api_key="YOUR_API_KEY",
+ openai_api_base=f"{init_server}/zhipuai/v1")
+
+ text = "你好"
+
+ query_result = embeddings.embed_query(text)
+
+ logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
+
+