mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
3.8兼容 (#3769)
* 增加使用说明 * 3.8兼容性配置 * fix * formater * 不同平台兼容测试用例 * embedding兼容 * 增加日志信息
This commit is contained in:
parent
4ce7ce0709
commit
2a33f9d4dd
@ -44,8 +44,8 @@ def init_server(model_platforms_shard: Dict,
|
||||
|
||||
provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
|
||||
model_platforms_shard['provider_platforms'] = provider_platforms
|
||||
|
||||
boot.serve(logging_conf=logging_conf)
|
||||
boot.logging_conf(logging_conf=logging_conf)
|
||||
boot.run()
|
||||
|
||||
async def pool_join_thread():
|
||||
await boot.join()
|
||||
|
||||
@ -42,3 +42,39 @@ make format
|
||||
make format_diff
|
||||
```
|
||||
当你对项目的一部分进行了更改,并希望确保更改的部分格式正确,而不影响代码库的其他部分时,这个命令特别有用。
|
||||
|
||||
|
||||
|
||||
### 开始使用
|
||||
|
||||
当项目安装完成,配置这个`model_providers.yaml`文件,即可完成平台加载
|
||||
> 注意: 在您配置平台之前,请确认平台依赖完整,例如智谱平台,您需要安装智谱sdk `pip install zhipuai`
|
||||
|
||||
model_providers包含了不同平台提供的 全局配置`provider_credential`,和模型配置`model_credential`
|
||||
不同平台所加载的配置有所不同,关于如何配置这个文件
|
||||
|
||||
请查看包`model_providers.core.model_runtime.model_providers`下方的平台 `yaml`文件
|
||||
例如`zhipuai.yaml`,这里给出了`provider_credential_schema`,其中包含了一个变量`api_key`
|
||||
|
||||
要加载智谱平台,操作如下
|
||||
|
||||
- 安装sdk
|
||||
```shell
|
||||
$ pip install zhipuai
|
||||
```
|
||||
|
||||
- 编辑`model_providers.yaml`
|
||||
|
||||
```yaml
|
||||
|
||||
zhipuai:
|
||||
|
||||
provider_credential:
|
||||
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.2'
|
||||
```
|
||||
|
||||
- `model-providers`可以运行pytest 测试
|
||||
```shell
|
||||
poetry run pytest tests/server_unit_test/test_init_server.py
|
||||
|
||||
```
|
||||
@ -27,3 +27,7 @@ xinference:
|
||||
model_uid: 'chatglm3-6b'
|
||||
|
||||
|
||||
zhipuai:
|
||||
|
||||
provider_credential:
|
||||
api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1'
|
||||
@ -16,7 +16,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--model-providers",
|
||||
type=str,
|
||||
default="D:\\project\\Langchain-Chatchat\\model-providers\\model_providers.yaml",
|
||||
default="/mnt/d/project/Langchain-Chatchat/model-providers/model_providers.yaml",
|
||||
help="run model_providers servers",
|
||||
dest="model_providers",
|
||||
)
|
||||
@ -36,7 +36,8 @@ if __name__ == "__main__":
|
||||
.build()
|
||||
)
|
||||
boot.set_app_event(started_event=None)
|
||||
boot.serve(logging_conf=logging_conf)
|
||||
boot.logging_conf(logging_conf=logging_conf)
|
||||
boot.run()
|
||||
|
||||
async def pool_join_thread():
|
||||
await boot.join()
|
||||
|
||||
@ -50,7 +50,7 @@ class SystemConfigurationResponse(BaseModel):
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
quota_configurations: List[QuotaConfiguration] = []
|
||||
|
||||
|
||||
class ProviderResponse(BaseModel):
|
||||
@ -65,8 +65,8 @@ class ProviderResponse(BaseModel):
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: list[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
supported_model_types: List[ModelType]
|
||||
configurate_methods: List[ConfigurateMethod]
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
preferred_provider_type: ProviderType
|
||||
@ -114,7 +114,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ModelResponse]
|
||||
models: List[ModelResponse]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
@ -18,6 +18,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette import EventSourceResponse
|
||||
@ -67,8 +69,13 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
self._port = port
|
||||
self._router = APIRouter()
|
||||
self._app = FastAPI()
|
||||
self._logging_conf = None
|
||||
self._server = None
|
||||
self._server_thread = None
|
||||
|
||||
def logging_conf(self,logging_conf: Optional[dict] = None):
|
||||
self._logging_conf = logging_conf
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
host = cfg.get("host", "127.0.0.1")
|
||||
@ -79,7 +86,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
)
|
||||
return cls(host=host, port=port)
|
||||
|
||||
def serve(self, logging_conf: Optional[dict] = None):
|
||||
def run(self):
|
||||
self._app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@ -125,18 +132,29 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
self._app.include_router(self._router)
|
||||
|
||||
config = Config(
|
||||
app=self._app, host=self._host, port=self._port, log_config=logging_conf
|
||||
app=self._app, host=self._host, port=self._port, log_config=self._logging_conf
|
||||
)
|
||||
server = Server(config)
|
||||
self._server = Server(config)
|
||||
|
||||
def run_server():
|
||||
server.run()
|
||||
self._server.shutdown_timeout = 2 # 设置为2秒
|
||||
|
||||
self._server.run()
|
||||
|
||||
self._server_thread = threading.Thread(target=run_server)
|
||||
self._server_thread.start()
|
||||
|
||||
async def join(self):
|
||||
await self._server_thread.join()
|
||||
def destroy(self):
|
||||
|
||||
logger.info("Shutting down server")
|
||||
self._server.should_exit = True # 设置退出标志
|
||||
self._server.shutdown() # 停止服务器
|
||||
|
||||
self.join()
|
||||
|
||||
def join(self):
|
||||
self._server_thread.join()
|
||||
|
||||
|
||||
def set_app_event(self, started_event: mp.Event = None):
|
||||
@self._app.on_event("startup")
|
||||
@ -159,7 +177,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
async def list_models(self, provider: str, request: Request):
|
||||
logger.info(f"Received list_models request for provider: {provider}")
|
||||
# 返回ModelType所有的枚举
|
||||
llm_models: list[AIModelEntity] = []
|
||||
llm_models: List[AIModelEntity] = []
|
||||
for model_type in ModelType.__members__.values():
|
||||
try:
|
||||
provider_model_bundle = (
|
||||
@ -176,7 +194,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
)
|
||||
logger.error(e)
|
||||
|
||||
# models list[AIModelEntity]转换称List[ModelCard]
|
||||
# modelsList[AIModelEntity]转换称List[ModelCard]
|
||||
|
||||
models_list = [
|
||||
ModelCard(id=model.model, object=model.model_type.to_origin_model_type())
|
||||
@ -191,16 +209,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
logger.info(
|
||||
f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
|
||||
)
|
||||
model_instance = self._provider_manager.get_model_instance(
|
||||
provider=provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embeddings_request.model,
|
||||
)
|
||||
texts = embeddings_request.input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
response = model_instance.invoke_text_embedding(texts=texts, user="abc-123")
|
||||
return await openai_embedding_text(response)
|
||||
try:
|
||||
model_instance = self._provider_manager.get_model_instance(
|
||||
provider=provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embeddings_request.model,
|
||||
)
|
||||
|
||||
# 判断embeddings_request.input是否为list
|
||||
input = ''
|
||||
if isinstance(embeddings_request.input, list):
|
||||
tokens = embeddings_request.input
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(embeddings_request.model)
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken.get_encoding(model)
|
||||
for i, token in enumerate(tokens):
|
||||
text = encoding.decode(token)
|
||||
input += text
|
||||
|
||||
else:
|
||||
input = embeddings_request.input
|
||||
|
||||
response = model_instance.invoke_text_embedding(texts=[input], user="abc-123")
|
||||
return await openai_embedding_text(response)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Error while creating embeddings: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except InvokeError as e:
|
||||
logger.error(f"Error while creating embeddings: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self, provider: str, request: Request, chat_request: ChatCompletionRequest
|
||||
@ -254,9 +297,11 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
||||
else:
|
||||
return await openai_chat_completion(response)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error while creating chat completion: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
except InvokeError as e:
|
||||
logger.error(f"Error while creating chat completion: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
|
||||
)
|
||||
@ -273,10 +318,11 @@ def run(
|
||||
cfg=cfg.get("run_openai_api", {})
|
||||
)
|
||||
api.set_app_event(started_event=started_event)
|
||||
api.serve(logging_conf=logging_conf)
|
||||
api.logging_conf(logging_conf=logging_conf)
|
||||
api.run()
|
||||
|
||||
async def pool_join_thread():
|
||||
await api.join()
|
||||
api.join()
|
||||
|
||||
asyncio.run(pool_join_thread())
|
||||
except SystemExit:
|
||||
|
||||
@ -145,7 +145,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_k: Optional[float] = None
|
||||
n: int = 1
|
||||
max_tokens: Optional[int] = 256
|
||||
stop: Optional[list[str]] = None
|
||||
stop: Optional[List[str]] = None
|
||||
stream: Optional[bool] = False
|
||||
|
||||
def to_model_parameters_dict(self, *args, **kwargs):
|
||||
|
||||
@ -112,7 +112,7 @@ class ProvidersWrapper:
|
||||
provider_models[model.provider.provider].append(model)
|
||||
|
||||
# convert to ProviderWithModelsResponse list
|
||||
providers_with_models: list[ProviderWithModelsResponse] = []
|
||||
providers_with_models: List[ProviderWithModelsResponse] = []
|
||||
for provider, models in provider_models.items():
|
||||
if not models:
|
||||
continue
|
||||
|
||||
@ -21,9 +21,9 @@ class ModelConfigEntity(BaseModel):
|
||||
model_schema: AIModelEntity
|
||||
mode: str
|
||||
provider_model_bundle: ProviderModelBundle
|
||||
credentials: dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
credentials: Dict[str, Any] = {}
|
||||
parameters: Dict[str, Any] = {}
|
||||
stop: List[str] = []
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
@ -40,7 +40,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
|
||||
Advanced Chat Prompt Template Entity.
|
||||
"""
|
||||
|
||||
messages: list[AdvancedChatMessageEntity]
|
||||
messages: List[AdvancedChatMessageEntity]
|
||||
|
||||
|
||||
class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
||||
@ -102,7 +102,7 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
@ -146,7 +146,7 @@ class DatasetEntity(BaseModel):
|
||||
Dataset Config Entity.
|
||||
"""
|
||||
|
||||
dataset_ids: list[str]
|
||||
dataset_ids: List[str]
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
|
||||
|
||||
@ -156,7 +156,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
@ -185,7 +185,7 @@ class AgentToolEntity(BaseModel):
|
||||
provider_type: Literal["builtin", "api"]
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
tool_parameters: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
@ -234,7 +234,7 @@ class AgentEntity(BaseModel):
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: list[AgentToolEntity] = None
|
||||
tools: List[AgentToolEntity] = None
|
||||
max_iteration: int = 5
|
||||
|
||||
|
||||
@ -245,7 +245,7 @@ class AppOrchestrationConfigEntity(BaseModel):
|
||||
|
||||
model_config: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
external_data_variables: list[ExternalDataVariableEntity] = []
|
||||
external_data_variables: List[ExternalDataVariableEntity] = []
|
||||
agent: Optional[AgentEntity] = None
|
||||
|
||||
# features
|
||||
@ -319,13 +319,13 @@ class ApplicationGenerateEntity(BaseModel):
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
inputs: dict[str, str]
|
||||
inputs: Dict[str, str]
|
||||
query: Optional[str] = None
|
||||
files: list[FileObj] = []
|
||||
files: List[FileObj] = []
|
||||
user_id: str
|
||||
# extras
|
||||
stream: bool
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = {}
|
||||
extras: Dict[str, Any] = {}
|
||||
|
||||
@ -47,12 +47,12 @@ class ImagePromptMessageFile(PromptMessageFile):
|
||||
|
||||
|
||||
class LCHumanMessageWithFiles(HumanMessage):
|
||||
# content: Union[str, list[Union[str, Dict]]]
|
||||
# content: Union[str,List[Union[str, Dict]]]
|
||||
content: str
|
||||
files: list[PromptMessageFile]
|
||||
files: List[PromptMessageFile]
|
||||
|
||||
|
||||
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
|
||||
def lc_messages_to_prompt_messages(messages: List[BaseMessage]) -> List[PromptMessage]:
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
@ -109,8 +109,8 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
|
||||
|
||||
|
||||
def prompt_messages_to_lc_messages(
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> list[BaseMessage]:
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> List[BaseMessage]:
|
||||
messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -31,7 +31,7 @@ class SimpleModelProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
supported_model_types: List[ModelType]
|
||||
|
||||
def __init__(self, provider_entity: ProviderEntity) -> None:
|
||||
"""
|
||||
@ -66,7 +66,7 @@ class DefaultModelProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
supported_model_types: List[ModelType]
|
||||
|
||||
|
||||
class DefaultModelEntity(BaseModel):
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -162,7 +161,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def get_provider_models(
|
||||
self, model_type: Optional[ModelType] = None, only_active: bool = False
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
) -> List[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get provider models.
|
||||
:param model_type: model type
|
||||
@ -189,8 +188,8 @@ class ProviderConfiguration(BaseModel):
|
||||
return sorted(provider_models, key=lambda x: x.model_type.value)
|
||||
|
||||
def _get_custom_provider_models(
|
||||
self, model_types: list[ModelType], provider_instance: ModelProvider
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
self, model_types: List[ModelType], provider_instance: ModelProvider
|
||||
) -> List[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get custom provider models.
|
||||
|
||||
@ -266,7 +265,7 @@ class ProviderConfigurations(BaseModel):
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
configurations: Dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -276,7 +275,7 @@ class ProviderConfigurations(BaseModel):
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
only_active: bool = False,
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
) -> List[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get available models.
|
||||
|
||||
@ -317,7 +316,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
return all_models
|
||||
|
||||
def to_list(self) -> list[ProviderConfiguration]:
|
||||
def to_list(self) -> List[ProviderConfiguration]:
|
||||
"""
|
||||
Convert to list.
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -68,7 +68,7 @@ class QuotaConfiguration(BaseModel):
|
||||
quota_limit: int
|
||||
quota_used: int
|
||||
is_valid: bool
|
||||
restrict_models: list[RestrictModel] = []
|
||||
restrict_models: List[RestrictModel] = []
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
@ -78,7 +78,7 @@ class SystemConfiguration(BaseModel):
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
quota_configurations: List[QuotaConfiguration] = []
|
||||
credentials: Optional[dict] = None
|
||||
|
||||
|
||||
@ -106,4 +106,4 @@ class CustomConfiguration(BaseModel):
|
||||
"""
|
||||
|
||||
provider: Optional[CustomProviderConfiguration] = None
|
||||
models: list[CustomModelConfiguration] = []
|
||||
models: List[CustomModelConfiguration] = []
|
||||
|
||||
@ -68,7 +68,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
"""
|
||||
|
||||
event = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
retriever_resources: List[dict]
|
||||
|
||||
|
||||
class AnnotationReplyEvent(AppQueueEvent):
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import IO, Optional, Union, cast
|
||||
from typing import IO, Generator, List, Optional, Union, cast
|
||||
|
||||
from model_providers.core.entities.provider_configuration import ProviderModelBundle
|
||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
@ -68,13 +67,13 @@ class ModelInstance:
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -105,7 +104,7 @@ class ModelInstance:
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
self, texts: list[str], user: Optional[str] = None
|
||||
self, texts: List[str], user: Optional[str] = None
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -125,7 +124,7 @@ class ModelInstance:
|
||||
def invoke_rerank(
|
||||
self,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
|
||||
@ -69,9 +69,9 @@ Model Runtime 分三层:
|
||||
|
||||
在这里我们需要先区分模型参数与模型凭据。
|
||||
|
||||
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: dict[str, any]**。
|
||||
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在DifyRuntime中,他们的参数名一般为**model_parameters: Dict[str, any]**。
|
||||
|
||||
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
|
||||
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中,他们的参数名一般为**credentials: Dict[str, any]**,Provider层的credentials会直接被传递到这一层,不需要再单独定义。
|
||||
|
||||
## 下一步
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -33,10 +33,10 @@ class Callback(ABC):
|
||||
llm_instance: AIModel,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -61,10 +61,10 @@ class Callback(ABC):
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
):
|
||||
@ -90,10 +90,10 @@ class Callback(ABC):
|
||||
result: LLMResult,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -119,10 +119,10 @@ class Callback(ABC):
|
||||
ex: Exception,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
@ -23,10 +23,10 @@ class LoggingCallback(Callback):
|
||||
llm_instance: AIModel,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -79,10 +79,10 @@ class LoggingCallback(Callback):
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
):
|
||||
@ -109,10 +109,10 @@ class LoggingCallback(Callback):
|
||||
result: LLMResult,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -154,10 +154,10 @@ class LoggingCallback(Callback):
|
||||
ex: Exception,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
|
||||
@ -71,7 +71,7 @@ All models need to uniformly implement the following 2 methods:
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
@ -95,8 +95,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
prompt_messages:List[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -157,8 +157,8 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -196,7 +196,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
texts:List[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -230,7 +230,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl
|
||||
- Pre-calculating Tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts:List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -251,7 +251,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
query: str, docs:List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
@ -498,7 +498,7 @@ class PromptMessage(ABC, BaseModel):
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
|
||||
content: Optional[str |List[PromptMessageContent]] = None # Supports two types: string and content list. The content list is designed to meet the needs of multimodal inputs. For more details, see the PromptMessageContent explanation.
|
||||
name: Optional[str] = None
|
||||
```
|
||||
|
||||
@ -539,7 +539,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
function: ToolCallFunction # tool call information
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
|
||||
tool_calls:List[ToolCall] = [] # The result of tool invocation in response from the model (returned only when tools are input and the model deems it necessary to invoke a tool).
|
||||
```
|
||||
|
||||
Where `tool_calls` are the list of `tool calls` returned by the model after invoking the model with the `tools` input.
|
||||
@ -593,7 +593,7 @@ class LLMResult(BaseModel):
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # Actual used modele
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
prompt_messages:List[PromptMessage] # prompt messages
|
||||
message: AssistantPromptMessage # response message
|
||||
usage: LLMUsage # usage info
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
@ -624,7 +624,7 @@ class LLMResultChunk(BaseModel):
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # Actual used modele
|
||||
prompt_messages: list[PromptMessage] # prompt messages
|
||||
prompt_messages:List[PromptMessage] # prompt messages
|
||||
system_fingerprint: Optional[str] = None # request fingerprint, refer to OpenAI definition
|
||||
delta: LLMResultChunkDelta
|
||||
```
|
||||
@ -660,7 +660,7 @@ class TextEmbeddingResult(BaseModel):
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
embeddings: list[list[float]] # List of embedding vectors, corresponding to the input texts list
|
||||
embeddings:List[List[float]] # List of embedding vectors, corresponding to the input texts list
|
||||
usage: EmbeddingUsage # Usage information
|
||||
```
|
||||
|
||||
@ -690,7 +690,7 @@ class RerankResult(BaseModel):
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # Actual model used
|
||||
docs: list[RerankDocument] # Reranked document list
|
||||
docs:List[RerankDocument] # Reranked document list
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
@ -160,8 +160,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
prompt_messages:List[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -184,8 +184,8 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
If the model does not provide a pre-calculated tokens interface, you can directly return 0.
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -226,7 +226,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -126,8 +126,8 @@ provider_credential_schema:
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
prompt_messages:List[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -166,8 +166,8 @@ provider_credential_schema:
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -283,7 +283,7 @@ provider_credential_schema:
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -80,7 +80,7 @@ class XinferenceProvider(Provider):
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
@ -95,7 +95,7 @@ class XinferenceProvider(Provider):
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
@ -127,8 +127,8 @@ class XinferenceProvider(Provider):
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
prompt_messages:List[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -189,8 +189,8 @@ class XinferenceProvider(Provider):
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -232,7 +232,7 @@ class XinferenceProvider(Provider):
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
texts:List[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -266,7 +266,7 @@ class XinferenceProvider(Provider):
|
||||
- 预计算 tokens
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts:List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -289,7 +289,7 @@ class XinferenceProvider(Provider):
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
query: str, docs:List[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
@ -538,7 +538,7 @@ class PromptMessage(ABC, BaseModel):
|
||||
Model class for prompt message.
|
||||
"""
|
||||
role: PromptMessageRole # 消息角色
|
||||
content: Optional[str | list[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
|
||||
content: Optional[str |List[PromptMessageContent]] = None # 支持两种类型,字符串和内容列表,内容列表是为了满足多模态的需要,可详见 PromptMessageContent 说明。
|
||||
name: Optional[str] = None # 名称,可选。
|
||||
```
|
||||
|
||||
@ -579,7 +579,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
function: ToolCallFunction # 工具调用信息
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
|
||||
tool_calls:List[ToolCall] = [] # 模型回复的工具调用结果(仅当传入 tools,并且模型认为需要调用工具时返回)
|
||||
```
|
||||
|
||||
其中 `tool_calls` 为调用模型传入 `tools` 后,由模型返回的 `tool call` 列表。
|
||||
@ -633,7 +633,7 @@ class LLMResult(BaseModel):
|
||||
Model class for llm result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
prompt_messages:List[PromptMessage] # prompt 消息列表
|
||||
message: AssistantPromptMessage # 回复消息
|
||||
usage: LLMUsage # 使用的 tokens 及费用信息
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
@ -664,7 +664,7 @@ class LLMResultChunk(BaseModel):
|
||||
Model class for llm result chunk.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
prompt_messages: list[PromptMessage] # prompt 消息列表
|
||||
prompt_messages:List[PromptMessage] # prompt 消息列表
|
||||
system_fingerprint: Optional[str] = None # 请求指纹,可参考 OpenAI 该参数定义
|
||||
delta: LLMResultChunkDelta # 每个迭代存在变化的内容
|
||||
```
|
||||
@ -700,7 +700,7 @@ class TextEmbeddingResult(BaseModel):
|
||||
Model class for text embedding result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
embeddings: list[list[float]] # embedding 向量列表,对应传入的 texts 列表
|
||||
embeddings:List[List[float]] # embedding 向量列表,对应传入的 texts 列表
|
||||
usage: EmbeddingUsage # 使用信息
|
||||
```
|
||||
|
||||
@ -730,7 +730,7 @@ class RerankResult(BaseModel):
|
||||
Model class for rerank result.
|
||||
"""
|
||||
model: str # 实际使用模型
|
||||
docs: list[RerankDocument] # 重排后的分段列表
|
||||
docs:List[RerankDocument] # 重排后的分段列表
|
||||
```
|
||||
|
||||
### RerankDocument
|
||||
|
||||
@ -76,8 +76,8 @@ pricing: # 价格信息
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
prompt_messages:List[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[List[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@ -116,8 +116,8 @@ pricing: # 价格信息
|
||||
若模型未提供预计算 tokens 接口,可直接返回 0。
|
||||
|
||||
```python
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages:List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -158,7 +158,7 @@ pricing: # 价格信息
|
||||
|
||||
```python
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[type[InvokeError],List[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from typing import Dict
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
DefaultParameterName,
|
||||
)
|
||||
|
||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||
PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
|
||||
DefaultParameterName.TEMPERATURE: {
|
||||
"label": {
|
||||
"en_US": "Temperature",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -78,7 +78,7 @@ class LLMResult(BaseModel):
|
||||
"""
|
||||
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
prompt_messages: List[PromptMessage]
|
||||
message: AssistantPromptMessage
|
||||
usage: LLMUsage
|
||||
system_fingerprint: Optional[str] = None
|
||||
@ -101,7 +101,7 @@ class LLMResultChunk(BaseModel):
|
||||
"""
|
||||
|
||||
model: str
|
||||
prompt_messages: list[PromptMessage]
|
||||
prompt_messages: List[PromptMessage]
|
||||
system_fingerprint: Optional[str] = None
|
||||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -110,7 +110,7 @@ class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None
|
||||
content: Optional[Union[str, List[PromptMessageContent]]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@ -145,7 +145,7 @@ class AssistantPromptMessage(PromptMessage):
|
||||
function: ToolCallFunction
|
||||
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = []
|
||||
tool_calls: List[ToolCall] = []
|
||||
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -158,9 +158,9 @@ class ProviderModel(BaseModel):
|
||||
model: str
|
||||
label: I18nObject
|
||||
model_type: ModelType
|
||||
features: Optional[list[ModelFeature]] = None
|
||||
features: Optional[List[ModelFeature]] = None
|
||||
fetch_from: FetchFrom
|
||||
model_properties: dict[ModelPropertyKey, Any]
|
||||
model_properties: Dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
|
||||
class Config:
|
||||
@ -182,7 +182,7 @@ class ParameterRule(BaseModel):
|
||||
min: Optional[float] = None
|
||||
max: Optional[float] = None
|
||||
precision: Optional[int] = None
|
||||
options: list[str] = []
|
||||
options: List[str] = []
|
||||
|
||||
|
||||
class PriceConfig(BaseModel):
|
||||
@ -201,7 +201,7 @@ class AIModelEntity(ProviderModel):
|
||||
Model class for AI model.
|
||||
"""
|
||||
|
||||
parameter_rules: list[ParameterRule] = []
|
||||
parameter_rules: List[ParameterRule] = []
|
||||
pricing: Optional[PriceConfig] = None
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -48,7 +48,7 @@ class FormOption(BaseModel):
|
||||
|
||||
label: I18nObject
|
||||
value: str
|
||||
show_on: list[FormShowOnObject] = []
|
||||
show_on: List[FormShowOnObject] = []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@ -66,10 +66,10 @@ class CredentialFormSchema(BaseModel):
|
||||
type: FormType
|
||||
required: bool = True
|
||||
default: Optional[str] = None
|
||||
options: Optional[list[FormOption]] = None
|
||||
options: Optional[List[FormOption]] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
max_length: int = 0
|
||||
show_on: list[FormShowOnObject] = []
|
||||
show_on: List[FormShowOnObject] = []
|
||||
|
||||
|
||||
class ProviderCredentialSchema(BaseModel):
|
||||
@ -77,7 +77,7 @@ class ProviderCredentialSchema(BaseModel):
|
||||
Model class for provider credential schema.
|
||||
"""
|
||||
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
credential_form_schemas: List[CredentialFormSchema]
|
||||
|
||||
|
||||
class FieldModelSchema(BaseModel):
|
||||
@ -91,7 +91,7 @@ class ModelCredentialSchema(BaseModel):
|
||||
"""
|
||||
|
||||
model: FieldModelSchema
|
||||
credential_form_schemas: list[CredentialFormSchema]
|
||||
credential_form_schemas: List[CredentialFormSchema]
|
||||
|
||||
|
||||
class SimpleProviderEntity(BaseModel):
|
||||
@ -103,8 +103,8 @@ class SimpleProviderEntity(BaseModel):
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
supported_model_types: list[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
supported_model_types: List[ModelType]
|
||||
models: List[AIModelEntity] = []
|
||||
|
||||
|
||||
class ProviderHelpEntity(BaseModel):
|
||||
@ -128,9 +128,9 @@ class ProviderEntity(BaseModel):
|
||||
icon_large: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
supported_model_types: list[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
models: list[ProviderModel] = []
|
||||
supported_model_types: List[ModelType]
|
||||
configurate_methods: List[ConfigurateMethod]
|
||||
models: List[ProviderModel] = []
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -17,4 +19,4 @@ class RerankResult(BaseModel):
|
||||
"""
|
||||
|
||||
model: str
|
||||
docs: list[RerankDocument]
|
||||
docs: List[RerankDocument]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from decimal import Decimal
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -25,5 +26,5 @@ class TextEmbeddingResult(BaseModel):
|
||||
"""
|
||||
|
||||
model: str
|
||||
embeddings: list[list[float]]
|
||||
embeddings: List[List[float]]
|
||||
usage: EmbeddingUsage
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import decimal
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
import yaml
|
||||
|
||||
@ -35,7 +35,7 @@ class AIModel(ABC):
|
||||
"""
|
||||
|
||||
model_type: ModelType
|
||||
model_schemas: list[AIModelEntity] = None
|
||||
model_schemas: List[AIModelEntity] = None
|
||||
started_at: float = 0
|
||||
|
||||
@abstractmethod
|
||||
@ -51,7 +51,7 @@ class AIModel(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
@ -133,7 +133,7 @@ class AIModel(ABC):
|
||||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def predefined_models(self) -> list[AIModelEntity]:
|
||||
def predefined_models(self) -> List[AIModelEntity]:
|
||||
"""
|
||||
Get all predefined models for given provider.
|
||||
|
||||
|
||||
@ -3,8 +3,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
from model_providers.core.model_runtime.callbacks.logging_callback import (
|
||||
@ -47,13 +46,13 @@ class LargeLanguageModel(AIModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@ -170,13 +169,13 @@ class LargeLanguageModel(AIModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper, ensure the response is a code block with output markdown quote
|
||||
@ -290,7 +289,7 @@ if you are not sure about the structure.
|
||||
def _code_block_mode_stream_processor(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
input_generator: Generator[LLMResultChunk, None, None],
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
@ -428,13 +427,13 @@ if you are not sure about the structure.
|
||||
model: str,
|
||||
result: Generator,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Invoke result generator
|
||||
@ -498,10 +497,10 @@ if you are not sure about the structure.
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -525,8 +524,8 @@ if you are not sure about the structure.
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -539,7 +538,7 @@ if you are not sure about the structure.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
|
||||
def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
|
||||
"""Cut off the text as soon as any stop words occur."""
|
||||
return re.split("|".join(stop), text, maxsplit=1)[0]
|
||||
|
||||
@ -575,7 +574,7 @@ if you are not sure about the structure.
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
|
||||
def get_parameter_rules(self, model: str, credentials: dict) -> List[ParameterRule]:
|
||||
"""
|
||||
Get parameter rules
|
||||
|
||||
@ -658,13 +657,13 @@ if you are not sure about the structure.
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
@ -706,13 +705,13 @@ if you are not sure about the structure.
|
||||
chunk: LLMResultChunk,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
@ -755,13 +754,13 @@ if you are not sure about the structure.
|
||||
model: str,
|
||||
result: LLMResult,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
@ -805,13 +804,13 @@ if you are not sure about the structure.
|
||||
model: str,
|
||||
ex: Exception,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
@ -911,7 +910,7 @@ if you are not sure about the structure.
|
||||
f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
|
||||
)
|
||||
elif parameter_rule.type == ParameterType.FLOAT:
|
||||
if not isinstance(parameter_value, float | int):
|
||||
if not isinstance(parameter_value, (float, int)):
|
||||
raise ValueError(
|
||||
f"Model Parameter {parameter_name} should be float."
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import yaml
|
||||
|
||||
@ -14,7 +15,7 @@ from model_providers.core.model_runtime.model_providers.__base.ai_model import A
|
||||
|
||||
class ModelProvider(ABC):
|
||||
provider_schema: ProviderEntity = None
|
||||
model_instance_map: dict[str, AIModel] = {}
|
||||
model_instance_map: Dict[str, AIModel] = {}
|
||||
|
||||
@abstractmethod
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
@ -65,7 +66,7 @@ class ModelProvider(ABC):
|
||||
|
||||
return provider_schema
|
||||
|
||||
def models(self, model_type: ModelType) -> list[AIModelEntity]:
|
||||
def models(self, model_type: ModelType) -> List[AIModelEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
|
||||
@ -19,7 +19,7 @@ class RerankModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
@ -51,7 +51,7 @@ class RerankModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import IO, Optional
|
||||
from typing import IO, List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@ -19,7 +19,7 @@ class Text2ImageModel(AIModel):
|
||||
prompt: str,
|
||||
model_parameters: dict,
|
||||
user: Optional[str] = None,
|
||||
) -> list[IO[bytes]]:
|
||||
) -> List[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
@ -44,7 +44,7 @@ class Text2ImageModel(AIModel):
|
||||
prompt: str,
|
||||
model_parameters: dict,
|
||||
user: Optional[str] = None,
|
||||
) -> list[IO[bytes]]:
|
||||
) -> List[IO[bytes]]:
|
||||
"""
|
||||
Invoke Text2Image model
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import (
|
||||
ModelPropertyKey,
|
||||
@ -23,7 +23,7 @@ class TextEmbeddingModel(AIModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -62,7 +62,7 @@ class TextEmbeddingModel(AIModel):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Dict, Generator, List, Optional, Type, Union, cast
|
||||
|
||||
import anthropic
|
||||
import requests
|
||||
@ -63,10 +62,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -92,9 +91,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -158,13 +157,13 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
@ -203,12 +202,12 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
user: Union[str, None] = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
@ -251,8 +250,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -297,7 +296,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Message,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@ -345,7 +344,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[MessageStreamEvent],
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
@ -424,8 +423,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
return credentials_kwargs
|
||||
|
||||
def _convert_prompt_messages(
|
||||
self, prompt_messages: list[PromptMessage]
|
||||
) -> tuple[str, list[dict]]:
|
||||
self, prompt_messages: List[PromptMessage]
|
||||
) -> tuple[str, List[dict]]:
|
||||
"""
|
||||
Convert prompt messages to dict list and system
|
||||
"""
|
||||
@ -560,7 +559,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt_anthropic(
|
||||
self, messages: list[PromptMessage]
|
||||
self, messages: List[PromptMessage]
|
||||
) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
@ -583,7 +582,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
return text.rstrip()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
|
||||
@ -29,7 +31,7 @@ class _CommonAzureOpenAI:
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import copy
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import AzureOpenAI, Stream
|
||||
@ -60,10 +59,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -102,8 +101,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
model_mode = self._get_ai_model_entity(
|
||||
credentials.get("base_model_name"), model
|
||||
@ -176,9 +175,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -215,7 +214,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Completion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
assistant_text = response.choices[0].text
|
||||
|
||||
@ -257,7 +256,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
full_text = ""
|
||||
for chunk in response:
|
||||
@ -321,10 +320,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -381,8 +380,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
assistant_message = response.choices[0].message
|
||||
# assistant_message_tool_calls = assistant_message.tool_calls
|
||||
@ -435,8 +434,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> Generator:
|
||||
index = 0
|
||||
full_assistant_content = ""
|
||||
@ -545,8 +544,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_tool_calls(
|
||||
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
response_tool_calls: List[
|
||||
Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
|
||||
],
|
||||
) -> List[AssistantPromptMessage.ToolCall]:
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
@ -566,7 +567,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_function_call(
|
||||
response_function_call: FunctionCall | ChoiceDeltaFunctionCall,
|
||||
response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall],
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
@ -651,7 +652,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
credentials: dict,
|
||||
text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||
@ -668,8 +669,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
credentials: dict,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
@ -743,7 +744,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
@staticmethod
|
||||
def _num_tokens_for_tools(
|
||||
encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
|
||||
encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
|
||||
) -> int:
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import copy
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
@ -35,7 +35,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
base_model_name = credentials["base_model_name"]
|
||||
@ -51,7 +51,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
@ -81,8 +81,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
results: List[List[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
@ -112,7 +112,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
embeddings=embeddings, usage=usage, model=base_model_name
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
@ -168,9 +168,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
def _embedding_invoke(
|
||||
model: str,
|
||||
client: AzureOpenAI,
|
||||
texts: Union[list[str], str],
|
||||
texts: Union[List[str], str],
|
||||
extra_model_kwargs: dict,
|
||||
) -> tuple[list[list[float]], int]:
|
||||
) -> tuple[List[list[float]], int]:
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
model=model,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from json import dumps, loads
|
||||
from typing import Any, Union
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
from requests import post
|
||||
|
||||
@ -25,10 +24,10 @@ class BaichuanMessage:
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
usage: Dict[str, int] = None
|
||||
stop_reason: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
@ -112,9 +111,9 @@ class BaichuanModel:
|
||||
self,
|
||||
model: str,
|
||||
stream: bool,
|
||||
messages: list[BaichuanMessage],
|
||||
parameters: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
messages: List[BaichuanMessage],
|
||||
parameters: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
if (
|
||||
model == "baichuan2-turbo"
|
||||
or model == "baichuan2-turbo-192k"
|
||||
@ -165,7 +164,7 @@ class BaichuanModel:
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if (
|
||||
model == "baichuan2-turbo"
|
||||
or model == "baichuan2-turbo-192k"
|
||||
@ -187,8 +186,8 @@ class BaichuanModel:
|
||||
self,
|
||||
model: str,
|
||||
stream: bool,
|
||||
messages: list[BaichuanMessage],
|
||||
parameters: dict[str, Any],
|
||||
messages: List[BaichuanMessage],
|
||||
parameters: Dict[str, Any],
|
||||
timeout: int,
|
||||
) -> Union[Generator, BaichuanMessage]:
|
||||
if (
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Dict, Generator, List, Type, Union, cast
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -49,13 +48,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
@ -71,14 +70,14 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
messages: List[PromptMessage],
|
||||
) -> int:
|
||||
"""Calculate num tokens for baichuan model"""
|
||||
|
||||
@ -149,13 +148,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
if tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError("Baichuan model doesn't support tools")
|
||||
|
||||
@ -195,7 +194,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: BaichuanMessage,
|
||||
) -> LLMResult:
|
||||
@ -216,7 +215,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[BaichuanMessage, None, None],
|
||||
) -> Generator:
|
||||
@ -258,7 +257,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -47,7 +47,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -93,8 +93,8 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
return result
|
||||
|
||||
def embedding(
|
||||
self, model: str, api_key, texts: list[str], user: Optional[str] = None
|
||||
) -> tuple[list[list[float]], int]:
|
||||
self, model: str, api_key, texts: List[str], user: Optional[str] = None
|
||||
) -> tuple[List[list[float]], int]:
|
||||
"""
|
||||
Embed given texts
|
||||
|
||||
@ -154,7 +154,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return [data["embedding"] for data in embeddings], usage["total_tokens"]
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -183,7 +183,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
@ -48,10 +47,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -77,8 +76,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
messages: list[PromptMessage] | str,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
messages: Union[List[PromptMessage], str],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -99,7 +98,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(
|
||||
self, model_prefix: str, messages: list[PromptMessage]
|
||||
self, model_prefix: str, messages: List[PromptMessage]
|
||||
) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
@ -190,7 +189,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(
|
||||
self, messages: list[PromptMessage], model_prefix: str
|
||||
self, messages: List[PromptMessage], model_prefix: str
|
||||
) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
|
||||
@ -216,9 +215,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
def _create_payload(
|
||||
self,
|
||||
model_prefix: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
):
|
||||
"""
|
||||
@ -282,9 +281,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -356,7 +355,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@ -436,7 +435,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -551,7 +550,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
@ -570,7 +569,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _map_client_to_invoke_error(
|
||||
self, error_code: str, error_msg: str
|
||||
) -> type[InvokeError]:
|
||||
) -> Type[InvokeError]:
|
||||
"""
|
||||
Map client error to invoke error
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 5.4 KiB |
@ -1,9 +0,0 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<mask id="mask0_8587_60212" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="1" y="2" width="23" height="21">
|
||||
<path d="M23.8 2H1V22.4H23.8V2Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_8587_60212)">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.86378 14.4544C3.86378 13.0981 4.67438 11.737 6.25923 10.6634C7.83827 9.59364 10.0864 8.89368 12.6282 8.89368C15.17 8.89368 17.4182 9.59364 18.9972 10.6634C19.7966 11.2049 20.399 11.8196 20.7998 12.4699C21.2873 11.5802 21.4969 10.6351 21.3835 9.69252C21.3759 9.62928 21.3824 9.56766 21.4005 9.5106C21.0758 9.21852 20.7259 8.94624 20.3558 8.69556C18.3272 7.32126 15.5915 6.50964 12.6282 6.50964C9.66497 6.50964 6.92918 7.32126 4.90058 8.69556C2.8778 10.0659 1.45703 12.0812 1.45703 14.4544C1.45703 16.8275 2.8778 18.8428 4.90058 20.2132C6.92918 21.5875 9.66497 22.3991 12.6282 22.3991C15.5915 22.3991 18.3272 21.5875 20.3558 20.2132C22.3786 18.8428 23.7994 16.8275 23.7994 14.4544C23.7994 12.9455 23.225 11.5813 22.2868 10.4355C22.2377 11.4917 21.8621 12.5072 21.238 13.43C21.3409 13.7686 21.3926 14.1116 21.3926 14.4544C21.3926 15.8107 20.582 17.1717 18.9972 18.2453C17.4182 19.3151 15.17 20.015 12.6282 20.015C10.0864 20.015 7.83827 19.3151 6.25923 18.2453C4.67438 17.1717 3.86378 15.8107 3.86378 14.4544Z" fill="#3762FF"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.84445 11.6838C3.20239 13.4885 3.35368 15.1156 4.18868 16.2838C5.02368 17.452 6.52281 18.1339 8.45459 18.1334C10.3826 18.133 12.6296 17.44 14.6939 15.9922C16.7581 14.5444 18.1643 12.6753 18.8052 10.8739C19.4473 9.0692 19.2959 7.44206 18.461 6.27392C17.626 5.10572 16.1269 4.42389 14.1951 4.42431C12.267 4.42475 10.0201 5.11774 7.95575 6.56552C5.89152 8.01332 4.48529 9.8825 3.84445 11.6838ZM1.53559 10.8778C2.36374 8.55002 4.11254 6.28976 6.54117 4.58645C8.96981 2.88312 11.7029 1.99995 14.1945 1.99939C16.6825 1.99884 19.0426 2.8912 20.4589 4.87263C21.8752 6.85406 21.941 9.35564 21.1141 11.6799C20.2859 14.0077 18.5371 16.2679 16.1085 17.9713C13.6798 19.6746 10.9468 20.5578 8.45513 20.5584C5.9672 20.5589 3.60706 19.6665 2.19075 17.6851C0.774446 15.7036 0.708677 13.2021 1.53559 10.8778Z" fill="#1041F3"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.2 KiB |
@ -1,36 +0,0 @@
|
||||
import logging
|
||||
|
||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from model_providers.core.model_runtime.errors.validate import (
|
||||
CredentialsValidateFailedError,
|
||||
)
|
||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGLMProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `chatglm3-6b` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model="chatglm3-6b", credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(
|
||||
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||
)
|
||||
raise ex
|
||||
@ -1,28 +0,0 @@
|
||||
provider: chatglm
|
||||
label:
|
||||
en_US: ChatGLM
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#F4F7FF"
|
||||
help:
|
||||
title:
|
||||
en_US: Deploy ChatGLM to your local
|
||||
zh_Hans: 部署您的本地 ChatGLM
|
||||
url:
|
||||
en_US: https://github.com/THUDM/ChatGLM3
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_base
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
en_US: Enter your API URL
|
||||
@ -1,21 +0,0 @@
|
||||
model: chatglm2-6b-32k
|
||||
label:
|
||||
en_US: ChatGLM2-6B-32K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2000
|
||||
min: 1
|
||||
max: 32000
|
||||
@ -1,21 +0,0 @@
|
||||
model: chatglm2-6b
|
||||
label:
|
||||
en_US: ChatGLM2-6B
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 256
|
||||
min: 1
|
||||
max: 2000
|
||||
@ -1,22 +0,0 @@
|
||||
model: chatglm3-6b-32k
|
||||
label:
|
||||
en_US: ChatGLM3-6B-32K
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 32000
|
||||
@ -1,22 +0,0 @@
|
||||
model: chatglm3-6b
|
||||
label:
|
||||
en_US: ChatGLM3-6B
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 256
|
||||
min: 1
|
||||
max: 8000
|
||||
@ -1,555 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from os.path import join
|
||||
from typing import Optional, cast
|
||||
|
||||
from httpx import Timeout
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APITimeoutError,
|
||||
AuthenticationError,
|
||||
ConflictError,
|
||||
InternalServerError,
|
||||
NotFoundError,
|
||||
OpenAI,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
Stream,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
)
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from model_providers.core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from model_providers.core.model_runtime.errors.validate import (
|
||||
CredentialsValidateFailedError,
|
||||
)
|
||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||
LargeLanguageModel,
|
||||
)
|
||||
from model_providers.core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# invoke model
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
return self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[
|
||||
UserPromptMessage(content="ping"),
|
||||
],
|
||||
model_parameters={
|
||||
"max_tokens": 16,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(str(e))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
APIConnectionError,
|
||||
APITimeoutError,
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError,
|
||||
ConflictError,
|
||||
NotFoundError,
|
||||
UnprocessableEntityError,
|
||||
PermissionDeniedError,
|
||||
],
|
||||
InvokeRateLimitError: [RateLimitError],
|
||||
InvokeAuthorizationError: [AuthenticationError],
|
||||
InvokeBadRequestError: [ValueError],
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials kwargs
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
self._check_chatglm_parameters(
|
||||
model=model, model_parameters=model_parameters, tools=tools
|
||||
)
|
||||
|
||||
kwargs = self._to_client_kwargs(credentials)
|
||||
# init model client
|
||||
client = OpenAI(**kwargs)
|
||||
|
||||
extra_model_kwargs = {}
|
||||
if stop:
|
||||
extra_model_kwargs["stop"] = stop
|
||||
|
||||
if user:
|
||||
extra_model_kwargs["user"] = user
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
extra_model_kwargs["functions"] = [
|
||||
helper.dump_model(tool) for tool in tools
|
||||
]
|
||||
|
||||
result = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
response=result,
|
||||
tools=tools,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_response(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
response=result,
|
||||
tools=tools,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
def _check_chatglm_parameters(
|
||||
self, model: str, model_parameters: dict, tools: list[PromptMessageTool]
|
||||
) -> None:
|
||||
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI Compatibility API
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError("User message content must be str")
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
message_dict["function_call"] = {
|
||||
"name": message.tool_calls[0].function.name,
|
||||
"arguments": message.tool_calls[0].function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
# check if last message is user message
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "function", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self, response_function_calls: list[FunctionCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_function_calls:
|
||||
for response_tool_call in response_function_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call.name, arguments=response_tool_call.arguments
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=0, type="function", function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _to_client_kwargs(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Convert invoke kwargs to client kwargs
|
||||
|
||||
:param stream: is stream response
|
||||
:param model_name: model name
|
||||
:param credentials: credentials dict
|
||||
:param model_parameters: model parameters
|
||||
:return: client kwargs
|
||||
"""
|
||||
client_kwargs = {
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"api_key": "1",
|
||||
"base_url": join(credentials["api_base"], "v1"),
|
||||
}
|
||||
|
||||
return client_kwargs
|
||||
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (
|
||||
delta.delta.content is None or delta.delta.content == ""
|
||||
):
|
||||
continue
|
||||
|
||||
# check if there is a tool call in the response
|
||||
function_calls = None
|
||||
if delta.delta.function_call:
|
||||
function_calls = [delta.delta.function_call]
|
||||
|
||||
assistant_message_tool_calls = self._extract_response_tool_calls(
|
||||
function_calls if function_calls else []
|
||||
)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else "",
|
||||
tool_calls=assistant_message_tool_calls,
|
||||
)
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
# temp_assistant_prompt_message is used to calculate usage
|
||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_response, tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(
|
||||
messages=prompt_messages, tools=tools
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(
|
||||
messages=[temp_assistant_prompt_message], tools=[]
|
||||
)
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=chunk.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
full_response += delta.delta.content
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return: llm response
|
||||
"""
|
||||
if len(response.choices) == 0:
|
||||
raise InvokeServerUnavailableError("Empty response")
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# convert function call to tool call
|
||||
function_calls = assistant_message.function_call
|
||||
tool_calls = self._extract_response_tool_calls(
|
||||
[function_calls] if function_calls else []
|
||||
)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_message.content, tool_calls=tool_calls
|
||||
)
|
||||
|
||||
prompt_tokens = self._num_tokens_from_messages(
|
||||
messages=prompt_messages, tools=tools
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(
|
||||
messages=[assistant_prompt_message], tools=tools
|
||||
)
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
usage=usage,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
|
||||
|
||||
it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
|
||||
As a temporary solution we use GPT2 tokenizer instead.
|
||||
|
||||
"""
|
||||
|
||||
def tokens(text: str):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
value = text
|
||||
|
||||
if key == "function_call":
|
||||
for t_key, t_value in value.items():
|
||||
num_tokens += tokens(t_key)
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += tokens(f_key)
|
||||
num_tokens += tokens(f_value)
|
||||
else:
|
||||
num_tokens += tokens(t_key)
|
||||
num_tokens += tokens(t_value)
|
||||
else:
|
||||
num_tokens += tokens(str(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling
|
||||
|
||||
:param encoding: encoding
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
|
||||
def tokens(text: str):
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
num_tokens = 0
|
||||
for tool in tools:
|
||||
# calculate num tokens for function object
|
||||
num_tokens += tokens("name")
|
||||
num_tokens += tokens(tool.name)
|
||||
num_tokens += tokens("description")
|
||||
num_tokens += tokens(tool.description)
|
||||
parameters = tool.parameters
|
||||
num_tokens += tokens("parameters")
|
||||
num_tokens += tokens("type")
|
||||
num_tokens += tokens(parameters.get("type"))
|
||||
if "properties" in parameters:
|
||||
num_tokens += tokens("properties")
|
||||
for key, value in parameters.get("properties").items():
|
||||
num_tokens += tokens(key)
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += tokens(field_key)
|
||||
if field_key == "enum":
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(enum_field)
|
||||
else:
|
||||
num_tokens += tokens(field_key)
|
||||
num_tokens += tokens(str(field_value))
|
||||
if "required" in parameters:
|
||||
num_tokens += tokens("required")
|
||||
for required_field in parameters["required"]:
|
||||
num_tokens += 3
|
||||
num_tokens += tokens(required_field)
|
||||
|
||||
return num_tokens
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Dict, List, Optional, Type, Union, cast
|
||||
|
||||
import cohere
|
||||
from cohere.responses import Chat, Generations
|
||||
@ -55,10 +55,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -103,8 +103,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -171,9 +171,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -216,7 +216,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Generations,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@ -256,7 +256,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: StreamingGenerations,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -317,9 +317,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -377,8 +377,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Chat,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@ -429,8 +429,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: StreamingChat,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
@ -517,8 +517,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
index += 1
|
||||
|
||||
def _convert_prompt_messages_to_message_and_chat_histories(
|
||||
self, prompt_messages: list[PromptMessage]
|
||||
) -> tuple[str, list[dict]]:
|
||||
self, prompt_messages: List[PromptMessage]
|
||||
) -> tuple[str, List[dict]]:
|
||||
"""
|
||||
Convert prompt messages to message and chat histories
|
||||
:param prompt_messages: prompt messages
|
||||
@ -586,7 +586,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
return response.length
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, model: str, credentials: dict, messages: list[PromptMessage]
|
||||
self, model: str, credentials: dict, messages: List[PromptMessage]
|
||||
) -> int:
|
||||
"""Calculate num tokens Cohere model."""
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
@ -650,7 +650,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
return entity
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
import cohere
|
||||
|
||||
@ -32,7 +32,7 @@ class CohereRerankModel(RerankModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
@ -99,7 +99,7 @@ class CohereRerankModel(RerankModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -35,7 +35,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -51,7 +51,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
@ -79,8 +79,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
results: List[List[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
@ -105,7 +105,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -161,8 +161,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(
|
||||
self, model: str, credentials: dict, texts: list[str]
|
||||
) -> tuple[list[list[float]], int]:
|
||||
self, model: str, credentials: dict, texts: List[str]
|
||||
) -> tuple[List[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
@ -216,7 +216,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
import google.api_core.exceptions as exceptions
|
||||
import google.generativeai as genai
|
||||
@ -66,10 +65,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -102,8 +101,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -118,7 +117,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
|
||||
@ -155,10 +154,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -249,7 +248,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@ -306,7 +305,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: GenerateContentResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -416,7 +415,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
return glm_content
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
@ -472,8 +471,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
tool_call = None
|
||||
if response_function_call:
|
||||
from google.protobuf import json_format
|
||||
|
||||
if isinstance(response_function_call, FunctionCall):
|
||||
map_composite_dict = dict(response_function_call.args.items())
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
@ -16,10 +15,10 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import (
|
||||
@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
|
||||
|
||||
class _CommonHuggingfaceHub:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
@ -44,10 +43,10 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -81,8 +80,8 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
@ -161,7 +160,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
|
||||
def _get_customizable_model_parameter_rules() -> List[ParameterRule]:
|
||||
temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
|
||||
DefaultParameterName.TEMPERATURE
|
||||
).copy()
|
||||
@ -253,7 +252,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
response: Generator,
|
||||
) -> Generator:
|
||||
index = -1
|
||||
@ -300,7 +299,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
response: any,
|
||||
) -> LLMResult:
|
||||
if isinstance(response, str):
|
||||
@ -355,7 +354,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
|
||||
|
||||
return model_info.pipeline_tag
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@ -35,7 +35,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
client = InferenceClient(token=credentials["huggingfacehub_api_token"])
|
||||
@ -62,7 +62,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
embeddings=self._mean_pooling(embeddings), usage=usage, model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
@ -132,12 +132,12 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
|
||||
return entity
|
||||
|
||||
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
|
||||
# Returned values are a list of floats, or a list[list[floats]]
|
||||
# Returned values are a list of floats, or aList[List[floats]]
|
||||
# (depending on if you sent a string or a list of string,
|
||||
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
|
||||
# This should be explained on the model's README.)
|
||||
@staticmethod
|
||||
def _mean_pooling(embeddings: list) -> list[float]:
|
||||
def _mean_pooling(embeddings: list) -> List[float]:
|
||||
# If automatic reduction by giving model, no need to mean_pooling.
|
||||
# For example one: List[List[float]]
|
||||
if not isinstance(embeddings[0][0], list):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
import httpx
|
||||
|
||||
@ -32,7 +32,7 @@ class JinaRerankModel(RerankModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
@ -108,7 +108,7 @@ class JinaRerankModel(RerankModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from requests import post
|
||||
|
||||
@ -34,7 +34,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
|
||||
api_base: str = "https://api.jina.ai/v1/embeddings"
|
||||
models: list[str] = [
|
||||
models: List[str] = [
|
||||
"jina-embeddings-v2-base-en",
|
||||
"jina-embeddings-v2-small-en",
|
||||
"jina-embeddings-v2-base-zh",
|
||||
@ -45,7 +45,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -113,7 +113,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -142,7 +142,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError("Invalid api key")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Dict, Generator, List, Type, Union, cast
|
||||
|
||||
from httpx import Timeout
|
||||
from openai import (
|
||||
@ -64,13 +63,13 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
@ -86,14 +85,14 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
# tools is not supported yet
|
||||
return self._num_tokens_from_messages(prompt_messages, tools=tools)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
|
||||
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for baichuan model
|
||||
@ -156,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling
|
||||
|
||||
@ -224,7 +223,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity | None:
|
||||
) -> Union[AIModelEntity, None]:
|
||||
completion_model = None
|
||||
if credentials["completion_type"] == "chat_completion":
|
||||
completion_model = LLMMode.CHAT.value
|
||||
@ -286,13 +285,13 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
kwargs = self._to_client_kwargs(credentials)
|
||||
# init model client
|
||||
client = OpenAI(**kwargs)
|
||||
@ -414,7 +413,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
return message_dict
|
||||
|
||||
def _convert_prompt_message_to_completion_prompts(
|
||||
self, messages: list[PromptMessage]
|
||||
self, messages: List[PromptMessage]
|
||||
) -> str:
|
||||
"""
|
||||
Convert PromptMessage to completion prompts
|
||||
@ -438,7 +437,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
def _handle_completion_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Completion,
|
||||
) -> LLMResult:
|
||||
@ -489,10 +488,10 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
tools: list[PromptMessageTool],
|
||||
tools: List[PromptMessageTool],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@ -547,10 +546,10 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
def _handle_completion_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[Completion],
|
||||
tools: list[PromptMessageTool],
|
||||
tools: List[PromptMessageTool],
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
@ -613,10 +612,10 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
tools: list[PromptMessageTool],
|
||||
tools: List[PromptMessageTool],
|
||||
) -> Generator:
|
||||
full_response = ""
|
||||
|
||||
@ -691,8 +690,8 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
full_response += delta.delta.content
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self, response_function_calls: list[FunctionCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
self, response_function_calls: List[FunctionCall]
|
||||
) -> List[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -714,7 +713,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
||||
return tool_calls
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from json import JSONDecodeError, dumps
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from requests import post
|
||||
from yarl import URL
|
||||
@ -42,7 +42,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -121,7 +121,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -138,7 +138,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
def _get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity | None:
|
||||
) -> Union[AIModelEntity, None]:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
@ -177,7 +177,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError(f"Invalid credentials: {e}")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from json import dumps, loads
|
||||
from typing import Any, Union
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
from requests import Response, post
|
||||
|
||||
@ -27,10 +26,10 @@ class MinimaxChatCompletion:
|
||||
model: str,
|
||||
api_key: str,
|
||||
group_id: str,
|
||||
prompt_messages: list[MinimaxMessage],
|
||||
prompt_messages: List[MinimaxMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[dict[str, Any]],
|
||||
stop: list[str] | None,
|
||||
tools: List[Dict[str, Any]],
|
||||
stop: Union[List[str], None],
|
||||
stream: bool,
|
||||
user: str,
|
||||
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from json import dumps, loads
|
||||
from typing import Any, Union
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
from requests import Response, post
|
||||
|
||||
@ -28,10 +27,10 @@ class MinimaxChatCompletionPro:
|
||||
model: str,
|
||||
api_key: str,
|
||||
group_id: str,
|
||||
prompt_messages: list[MinimaxMessage],
|
||||
prompt_messages: List[MinimaxMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[dict[str, Any]],
|
||||
stop: list[str] | None,
|
||||
tools: List[Dict[str, Any]],
|
||||
stop: Union[List[str], None],
|
||||
stream: bool,
|
||||
user: str,
|
||||
) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
||||
|
||||
@ -58,13 +58,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
return self._generate(
|
||||
model,
|
||||
credentials,
|
||||
@ -110,13 +110,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages, tools)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
|
||||
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for minimax model
|
||||
@ -137,13 +137,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
|
||||
"""
|
||||
@ -227,7 +227,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: MinimaxMessage,
|
||||
) -> LLMResult:
|
||||
@ -250,7 +250,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[MinimaxMessage, None, None],
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
@ -319,7 +319,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -11,11 +11,11 @@ class MinimaxMessage:
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
usage: Dict[str, int] = None
|
||||
stop_reason: str = ""
|
||||
function_call: dict[str, Any] = None
|
||||
function_call: Dict[str, Any] = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
|
||||
return {
|
||||
"sender_type": "BOT",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from requests import post
|
||||
|
||||
@ -44,7 +44,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -103,7 +103,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -146,7 +146,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||
raise InternalServerError(msg)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
@ -16,10 +15,10 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -42,7 +42,7 @@ class ModelProviderExtension(BaseModel):
|
||||
class ModelProviderFactory:
|
||||
# init cache provider by default
|
||||
init_cache: bool = False
|
||||
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
||||
model_provider_extensions: Dict[str, ModelProviderExtension] = None
|
||||
|
||||
def __init__(self, init_cache: bool = False) -> None:
|
||||
# for cache in memory
|
||||
@ -51,7 +51,7 @@ class ModelProviderFactory:
|
||||
|
||||
def get_providers(
|
||||
self, provider_name: Union[str, set] = ""
|
||||
) -> list[ProviderEntity]:
|
||||
) -> List[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
:return: list of providers
|
||||
@ -159,8 +159,8 @@ class ModelProviderFactory:
|
||||
self,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_configs: Optional[list[ProviderConfig]] = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
provider_configs: Optional[List[ProviderConfig]] = None,
|
||||
) -> List[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
@ -234,7 +234,7 @@ class ModelProviderFactory:
|
||||
|
||||
return model_provider_instance
|
||||
|
||||
def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
|
||||
def _get_model_provider_map(self) -> Dict[str, ModelProviderExtension]:
|
||||
if self.model_provider_extensions:
|
||||
return self.model_provider_extensions
|
||||
|
||||
@ -254,7 +254,7 @@ class ModelProviderFactory:
|
||||
position_map = get_position_map(model_providers_path)
|
||||
|
||||
# traverse all model_provider_dir_paths
|
||||
model_providers: list[ModelProviderExtension] = []
|
||||
model_providers: List[ModelProviderExtension] = []
|
||||
for model_provider_dir_path in model_provider_dir_paths:
|
||||
# get model_provider dir name
|
||||
model_provider_name = os.path.basename(model_provider_dir_path)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
@ -16,10 +16,10 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Dict, List, Optional, Type, Union, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@ -63,10 +63,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -97,8 +97,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -159,9 +159,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -258,7 +258,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
credentials: dict,
|
||||
completion_type: LLMMode,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm completion response
|
||||
@ -310,7 +310,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
credentials: dict,
|
||||
completion_type: LLMMode,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm completion stream response
|
||||
@ -462,7 +462,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
|
||||
def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
Calculate num tokens.
|
||||
|
||||
@ -700,7 +700,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
return entity
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
@ -48,7 +48,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -123,7 +123,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
embeddings=batched_embeddings, usage=usage, model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
@ -211,7 +211,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
|
||||
@ -35,7 +37,7 @@ class _CommonOpenAI:
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI, Stream
|
||||
@ -72,10 +72,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -128,13 +128,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
@ -194,12 +194,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
user: Union[str, None] = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
@ -242,12 +242,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
user: Union[str, None] = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
@ -287,8 +287,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -366,7 +366,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def remote_models(self, credentials: dict) -> list[AIModelEntity]:
|
||||
def remote_models(self, credentials: dict) -> List[AIModelEntity]:
|
||||
"""
|
||||
Return remote models if credentials are provided.
|
||||
|
||||
@ -424,9 +424,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -479,7 +479,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Completion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm completion response
|
||||
@ -528,7 +528,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[Completion],
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm completion stream response
|
||||
@ -599,10 +599,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -676,8 +676,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: ChatCompletion,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
@ -740,8 +740,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: Stream[ChatCompletionChunk],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
@ -851,8 +851,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self,
|
||||
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
response_tool_calls: List[
|
||||
Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
|
||||
],
|
||||
) -> List[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -877,7 +879,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(
|
||||
self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
|
||||
self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
@ -967,7 +969,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
@ -992,8 +994,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
@ -1068,7 +1070,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(
|
||||
self, encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
|
||||
self, encoding: tiktoken.Encoding, tools: List[PromptMessageTool]
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling with tiktoken package.
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types import ModerationCreateResponse
|
||||
@ -82,7 +82,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _moderation_invoke(
|
||||
self, model: str, client: OpenAI, texts: list[str]
|
||||
self, model: str, client: OpenAI, texts: List[str]
|
||||
) -> ModerationCreateResponse:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
@ -31,7 +31,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -58,7 +58,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
@ -89,8 +89,8 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
@ -118,7 +118,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
|
||||
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -167,9 +167,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
client: OpenAI,
|
||||
texts: Union[list[str], str],
|
||||
texts: Union[List[str], str],
|
||||
extra_model_kwargs: dict,
|
||||
) -> tuple[list[list[float]], int]:
|
||||
) -> Tuple[List[List[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import requests
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import (
|
||||
@ -12,7 +14,7 @@ from model_providers.core.model_runtime.errors.invoke import (
|
||||
|
||||
class _CommonOAI_API_Compat:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Union, cast
|
||||
from typing import List, Optional, Union, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@ -61,10 +61,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -98,8 +98,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -282,10 +282,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -384,7 +384,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -516,7 +516,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
response_json = response.json()
|
||||
|
||||
@ -649,7 +649,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||
self, model: str, text: str, tools: Optional[List[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens for model with gpt2 tokenizer.
|
||||
@ -669,8 +669,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens with GPT2 tokenizer.
|
||||
@ -722,7 +722,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling with tiktoken package.
|
||||
|
||||
@ -769,8 +769,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
return num_tokens
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self, response_tool_calls: list[dict]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
self, response_tool_calls: List[dict]
|
||||
) -> List[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy as np
|
||||
@ -40,7 +40,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -131,7 +131,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
embeddings=batched_embeddings, usage=usage, model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Approximate number of tokens for given messages using GPT2 tokenizer
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Dict, Generator, List, Type, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
@ -54,13 +54,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
return self._generate(
|
||||
model,
|
||||
credentials,
|
||||
@ -105,13 +105,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages, tools)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self, messages: list[PromptMessage], tools: list[PromptMessageTool]
|
||||
self, messages: List[PromptMessage], tools: List[PromptMessageTool]
|
||||
) -> int:
|
||||
"""
|
||||
Calculate num tokens for OpenLLM model
|
||||
@ -124,13 +124,13 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
client = OpenLLMGenerate()
|
||||
response = client.generate(
|
||||
model_name=model,
|
||||
@ -183,7 +183,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: OpenLLMGenerateMessage,
|
||||
) -> LLMResult:
|
||||
@ -206,7 +206,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[OpenLLMGenerateMessage, None, None],
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
@ -249,7 +249,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity | None:
|
||||
) -> Union[AIModelEntity, None]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
@ -298,7 +298,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
return entity
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -20,10 +20,10 @@ class OpenLLMGenerateMessage:
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
usage: Dict[str, int] = None
|
||||
stop_reason: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
@ -40,9 +40,9 @@ class OpenLLMGenerate:
|
||||
server_url: str,
|
||||
model_name: str,
|
||||
stream: bool,
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str],
|
||||
prompt_messages: list[OpenLLMGenerateMessage],
|
||||
model_parameters: Dict[str, Any],
|
||||
stop: List[str],
|
||||
prompt_messages: List[OpenLLMGenerateMessage],
|
||||
user: str,
|
||||
) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
|
||||
if not server_url:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from json import dumps
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from requests import post
|
||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||
@ -35,7 +35,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
@ -89,7 +89,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
return result
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -118,7 +118,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
|
||||
raise CredentialsValidateFailedError("Invalid server_url")
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import (
|
||||
@ -8,5 +10,5 @@ from model_providers.core.model_runtime.errors.invoke import (
|
||||
|
||||
class _CommonReplicate:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
return {InvokeBadRequestError: [ReplicateError, ModelError]}
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from replicate import Client as ReplicateClient
|
||||
from replicate.exceptions import ReplicateError
|
||||
@ -43,10 +42,10 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -86,8 +85,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
@ -167,7 +166,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
@classmethod
|
||||
def _get_customizable_model_parameter_rules(
|
||||
cls, model: str, credentials: dict
|
||||
) -> list[ParameterRule]:
|
||||
) -> List[ParameterRule]:
|
||||
version = credentials["model_version"]
|
||||
|
||||
client = ReplicateClient(
|
||||
@ -215,8 +214,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prediction: Prediction,
|
||||
stop: list[str],
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: List[str],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
index = -1
|
||||
current_completion: str = ""
|
||||
@ -281,8 +280,8 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prediction: Prediction,
|
||||
stop: list[str],
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: List[str],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
current_completion: str = ""
|
||||
stop_condition_reached = False
|
||||
@ -332,7 +331,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
elif param_type == "string":
|
||||
return "string"
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
@ -31,7 +31,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
texts: List[str],
|
||||
user: Optional[str] = None,
|
||||
) -> TextEmbeddingResult:
|
||||
client = ReplicateClient(
|
||||
@ -52,7 +52,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
|
||||
return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: List[str]) -> int:
|
||||
num_tokens = 0
|
||||
for text in texts:
|
||||
num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
@ -124,8 +124,8 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
|
||||
client: ReplicateClient,
|
||||
replicate_model_version: str,
|
||||
text_input_key: str,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
texts: List[str],
|
||||
) -> List[List[float]]:
|
||||
if text_input_key in ("text", "inputs"):
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
@ -37,10 +37,10 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -66,8 +66,8 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -109,9 +109,9 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -171,13 +171,12 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
client: SparkLLMClient,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
@ -222,7 +221,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
client: SparkLLMClient,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -300,7 +299,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
@ -317,7 +316,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
|
||||
return text.rstrip()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from model_providers.core.model_runtime.entities.message_entities import (
|
||||
@ -21,10 +21,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -50,10 +50,10 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -81,8 +81,8 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
@ -11,7 +13,7 @@ class _CommonTongyi:
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Tongyi
|
||||
@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult
|
||||
|
||||
class EnhanceTongyi(Tongyi):
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params = {"top_p": self.top_p, "api_key": self.dashscope_api_key}
|
||||
|
||||
@ -16,13 +16,13 @@ class EnhanceTongyi(Tongyi):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
params: dict[str, Any] = {
|
||||
params: Dict[str, Any] = {
|
||||
**{"model": self.model_name},
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from dashscope import get_tokenizer
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
@ -50,10 +50,10 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -79,14 +79,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Wrapper for code block mode
|
||||
"""
|
||||
@ -174,8 +174,8 @@ if you are not sure about the structure.
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@ -220,9 +220,9 @@ if you are not sure about the structure.
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
@ -289,7 +289,7 @@ if you are not sure about the structure.
|
||||
model: str,
|
||||
credentials: dict,
|
||||
response: DashScopeAPIResponse,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@ -326,7 +326,7 @@ if you are not sure about the structure.
|
||||
model: str,
|
||||
credentials: dict,
|
||||
responses: Generator,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -412,7 +412,7 @@ if you are not sure about the structure.
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
@ -429,8 +429,8 @@ if you are not sure about the structure.
|
||||
return text.rstrip()
|
||||
|
||||
def _convert_prompt_messages_to_tongyi_messages(
|
||||
self, prompt_messages: list[PromptMessage]
|
||||
) -> list[dict]:
|
||||
self, prompt_messages: List[PromptMessage]
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Convert prompt messages to tongyi messages
|
||||
|
||||
@ -466,7 +466,7 @@ if you are not sure about the structure.
|
||||
return tongyi_messages
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -4,7 +4,7 @@ from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydub import AudioSegment
|
||||
|
||||
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from json import dumps, loads
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
from typing import Any, Dict, Generator, List, Union
|
||||
|
||||
from requests import Response, post
|
||||
|
||||
@ -19,7 +18,7 @@ from model_providers.core.model_runtime.model_providers.wenxin.llm.ernie_bot_err
|
||||
)
|
||||
|
||||
# map api_key to access_token
|
||||
baidu_access_tokens: dict[str, "BaiduAccessToken"] = {}
|
||||
baidu_access_tokens: Dict[str, "BaiduAccessToken"] = {}
|
||||
baidu_access_tokens_lock = Lock()
|
||||
|
||||
|
||||
@ -118,10 +117,10 @@ class ErnieMessage:
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
usage: Dict[str, int] = None
|
||||
stop_reason: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
@ -156,11 +155,11 @@ class ErnieBotModel:
|
||||
self,
|
||||
model: str,
|
||||
stream: bool,
|
||||
messages: list[ErnieMessage],
|
||||
parameters: dict[str, Any],
|
||||
messages: List[ErnieMessage],
|
||||
parameters: Dict[str, Any],
|
||||
timeout: int,
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
tools: List[PromptMessageTool],
|
||||
stop: List[str],
|
||||
user: str,
|
||||
) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
|
||||
# check parameters
|
||||
@ -243,15 +242,15 @@ class ErnieBotModel:
|
||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||
return token.access_token
|
||||
|
||||
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
||||
def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]:
|
||||
return [ErnieMessage(message.content, message.role) for message in messages]
|
||||
|
||||
def _check_parameters(
|
||||
self,
|
||||
model: str,
|
||||
parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
parameters: Dict[str, Any],
|
||||
tools: List[PromptMessageTool],
|
||||
stop: List[str],
|
||||
) -> None:
|
||||
if model not in self.api_bases:
|
||||
raise BadRequestError(f"Invalid model: {model}")
|
||||
@ -276,13 +275,13 @@ class ErnieBotModel:
|
||||
def _build_request_body(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[ErnieMessage],
|
||||
messages: List[ErnieMessage],
|
||||
stream: bool,
|
||||
parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
parameters: Dict[str, Any],
|
||||
tools: List[PromptMessageTool],
|
||||
stop: List[str],
|
||||
user: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
# if model in self.function_calling_supports:
|
||||
# return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
|
||||
return self._build_chat_request_body(
|
||||
@ -292,13 +291,13 @@ class ErnieBotModel:
|
||||
def _build_function_calling_request_body(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[ErnieMessage],
|
||||
messages: List[ErnieMessage],
|
||||
stream: bool,
|
||||
parameters: dict[str, Any],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str],
|
||||
parameters: Dict[str, Any],
|
||||
tools: List[PromptMessageTool],
|
||||
stop: List[str],
|
||||
user: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
if len(messages) % 2 == 0:
|
||||
raise BadRequestError("The number of messages should be odd.")
|
||||
if messages[0].role == "function":
|
||||
@ -311,12 +310,12 @@ class ErnieBotModel:
|
||||
def _build_chat_request_body(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[ErnieMessage],
|
||||
messages: List[ErnieMessage],
|
||||
stream: bool,
|
||||
parameters: dict[str, Any],
|
||||
stop: list[str],
|
||||
parameters: Dict[str, Any],
|
||||
stop: List[str],
|
||||
user: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
if len(messages) == 0:
|
||||
raise BadRequestError("The number of messages should not be zero.")
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Dict, Generator, List, Optional, Type, Union, cast
|
||||
|
||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||
@ -59,13 +58,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
@ -81,13 +80,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[List[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
callbacks: List[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
@ -140,12 +139,12 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
user: Union[str, None] = None,
|
||||
response_format: str = "JSON",
|
||||
) -> None:
|
||||
"""
|
||||
@ -187,15 +186,15 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
# tools is not supported yet
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
messages: List[PromptMessage],
|
||||
) -> int:
|
||||
"""Calculate num tokens for baichuan model"""
|
||||
|
||||
@ -234,13 +233,13 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
instance = ErnieBotModel(
|
||||
api_key=credentials["api_key"],
|
||||
secret_key=credentials["secret_key"],
|
||||
@ -304,7 +303,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: ErnieMessage,
|
||||
) -> LLMResult:
|
||||
@ -325,7 +324,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[ErnieMessage, None, None],
|
||||
) -> Generator:
|
||||
@ -367,7 +366,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
from typing import Dict, List, Union, cast
|
||||
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
@ -81,13 +81,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
invoke LLM
|
||||
|
||||
@ -168,8 +168,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
) -> int:
|
||||
"""
|
||||
get number of tokens
|
||||
@ -181,8 +181,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
is_completion_model: bool = False,
|
||||
) -> int:
|
||||
def tokens(text: str):
|
||||
@ -240,7 +240,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
||||
def _num_tokens_for_tools(self, tools: List[PromptMessageTool]) -> int:
|
||||
"""
|
||||
Calculate num tokens for tool calling
|
||||
|
||||
@ -284,7 +284,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
|
||||
def _convert_prompt_message_to_text(self, message: List[PromptMessage]) -> str:
|
||||
"""
|
||||
convert prompt message to text
|
||||
"""
|
||||
@ -337,7 +337,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity | None:
|
||||
) -> Union[AIModelEntity, None]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
@ -412,14 +412,14 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: List[PromptMessage],
|
||||
model_parameters: dict,
|
||||
extra_model_kwargs: XinferenceModelExtraParameter,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
tools: Union[List[PromptMessageTool], None] = None,
|
||||
stop: Union[List[str], None] = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
user: Union[str, None] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
generate text from LLM
|
||||
|
||||
@ -525,8 +525,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _extract_response_tool_calls(
|
||||
self,
|
||||
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
response_tool_calls: Union[
|
||||
List[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]
|
||||
],
|
||||
) -> List[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
@ -551,7 +553,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
return tool_calls
|
||||
|
||||
def _extract_response_function_call(
|
||||
self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall
|
||||
self, response_function_call: Union[FunctionCall, ChoiceDeltaFunctionCall]
|
||||
) -> AssistantPromptMessage.ToolCall:
|
||||
"""
|
||||
Extract function call from response
|
||||
@ -576,8 +578,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
resp: ChatCompletion,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
@ -633,8 +635,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
resp: Iterator[ChatCompletionChunk],
|
||||
) -> Generator:
|
||||
"""
|
||||
@ -721,8 +723,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
resp: Completion,
|
||||
) -> LLMResult:
|
||||
"""
|
||||
@ -765,8 +767,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
prompt_messages: List[PromptMessage],
|
||||
tools: List[PromptMessageTool],
|
||||
resp: Iterator[Completion],
|
||||
) -> Generator:
|
||||
"""
|
||||
@ -834,7 +836,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
full_response += delta.text
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from xinference_client.client.restful.restful_client import (
|
||||
Client,
|
||||
@ -41,7 +41,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
docs: List[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
@ -133,7 +133,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> Dict[Type[InvokeError], List[Type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
@ -152,7 +152,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity | None:
|
||||
) -> Union[AIModelEntity, None]:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user