Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-07 07:23:29 +08:00)

Commit 4c040a49be (parent 9818bd2a88): Format code
@@ -1,17 +1,23 @@
 from chatchat.configs import MODEL_PLATFORMS

 from model_providers.core.model_manager import ModelManager


 def _to_custom_provide_configuration():
     provider_name_to_provider_records_dict = {}
     provider_name_to_provider_model_records_dict = {}
-    return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict
+    return (
+        provider_name_to_provider_records_dict,
+        provider_name_to_provider_model_records_dict,
+    )


 # 基于配置管理器创建的模型实例
 provider_manager = ModelManager(
     provider_name_to_provider_records_dict={
-        'openai': {
-            'openai_api_key': "sk-4M9LYF",
+        "openai": {
+            "openai_api_key": "sk-4M9LYF",
         }
     },
-    provider_name_to_provider_model_records_dict={}
+    provider_name_to_provider_model_records_dict={},
 )
@@ -1,51 +1,58 @@
 import os
-from typing import cast, Generator
+from typing import Generator, cast

 from model_providers.core.model_manager import ModelManager
-from model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
-from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
+from model_providers.core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    UserPromptMessage,
+)
 from model_providers.core.model_runtime.entities.model_entities import ModelType

-if __name__ == '__main__':
+if __name__ == "__main__":
     # 基于配置管理器创建的模型实例
     provider_manager = ModelManager(
         provider_name_to_provider_records_dict={
-            'openai': {
-                'openai_api_key': "sk-4M9LYF",
+            "openai": {
+                "openai_api_key": "sk-4M9LYF",
             }
         },
-        provider_name_to_provider_model_records_dict={}
+        provider_name_to_provider_model_records_dict={},
     )

     #
     # Invoke model
-    model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
+    model_instance = provider_manager.get_model_instance(
+        provider="openai", model_type=ModelType.LLM, model="gpt-4"
+    )

     response = model_instance.invoke_llm(
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'plugin_web_search': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "plugin_web_search": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )

     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         total_message += chunk.delta.message.content
-        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
+        assert (
+            len(chunk.delta.message.content) > 0
+            if not chunk.delta.finish_reason
+            else True
+        )
     print(total_message)
-    assert '参考资料' in total_message
+    assert "参考资料" in total_message
@@ -1,60 +1,58 @@
 import asyncio
-import os
-from typing import Optional, Any, Dict

-from fastapi import (APIRouter,
-                     FastAPI,
-                     HTTPException,
-                     Response,
-                     Request,
-                     status
-                     )
-import logging
-from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
 import json
-import pprint
-import tiktoken
-from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest, \
-    ChatCompletionResponse, ModelList, EmbeddingsResponse, ChatCompletionStreamResponse, FunctionAvailable
-from uvicorn import Config, Server
-from fastapi.middleware.cors import CORSMiddleware
+import logging
 import multiprocessing as mp
+import os
+import pprint
 import threading
+from typing import Any, Dict, Optional

+import tiktoken
+from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
+from fastapi.middleware.cors import CORSMiddleware
 from sse_starlette import EventSourceResponse
+from uvicorn import Config, Server

-from model_providers.core.model_runtime.entities.message_entities import UserPromptMessage
+from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
+from model_providers.core.bootstrap.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionStreamResponse,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
+    FunctionAvailable,
+    ModelList,
+)
+from model_providers.core.model_runtime.entities.message_entities import (
+    UserPromptMessage,
+)
 from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from model_providers.core.utils.generic import dictify, jsonify

 from model_providers.core.model_runtime.model_providers import model_provider_factory
+from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
+    LargeLanguageModel,
+)
+from model_providers.core.utils.generic import dictify, jsonify

 logger = logging.getLogger(__name__)


-async def create_stream_chat_completion(model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest):
+async def create_stream_chat_completion(
+    model_type_instance: LargeLanguageModel, chat_request: ChatCompletionRequest
+):
     try:


         response = model_type_instance.invoke(
             model=chat_request.model,
             credentials={
-                'openai_api_key': "sk-",
-                'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+                "openai_api_key": "sk-",
+                "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+                "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
             },
-            prompt_messages=[
-                UserPromptMessage(
-                    content='北京今天的天气怎么样'
-                )
-            ],
-            model_parameters={
-                **chat_request.to_model_parameters_dict()
-            },
+            prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
+            model_parameters={**chat_request.to_model_parameters_dict()},
             stop=chat_request.stop,
             stream=chat_request.stream,
-            user="abc-123"
+            user="abc-123",
         )
         return response

@@ -81,7 +79,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         host = cfg.get("host", "127.0.0.1")
         port = cfg.get("port", 20000)

-        logger.info(f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}")
+        logger.info(
+            f"Starting openai Bootstrap Server Lifecycle at endpoint: http://{host}:{port}"
+        )
         return cls(host=host, port=port)

     def serve(self, logging_conf: Optional[dict] = None):
@@ -140,8 +140,12 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
     async def list_models(self, request: Request):
         pass

-    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
-        logger.info(f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}")
+    async def create_embeddings(
+        self, request: Request, embeddings_request: EmbeddingsRequest
+    ):
+        logger.info(
+            f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}"
+        )
         if os.environ["API_KEY"] is None:
             authorization = request.headers.get("Authorization")
             authorization = authorization.split("Bearer ")[-1]
@@ -171,42 +175,41 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
         )
         return EmbeddingsResponse(**dictify(response))

-    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
-        logger.info(f"Received chat completion request: {pprint.pformat(chat_request.dict())}")
+    async def create_chat_completion(
+        self, request: Request, chat_request: ChatCompletionRequest
+    ):
+        logger.info(
+            f"Received chat completion request: {pprint.pformat(chat_request.dict())}"
+        )
         if os.environ["API_KEY"] is None:
             authorization = request.headers.get("Authorization")
             authorization = authorization.split("Bearer ")[-1]
         else:
             authorization = os.environ["API_KEY"]
-        model_provider_factory.get_providers(provider_name='openai')
-        provider_instance = model_provider_factory.get_provider_instance('openai')
+        model_provider_factory.get_providers(provider_name="openai")
+        provider_instance = model_provider_factory.get_provider_instance("openai")
         model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
         if chat_request.stream:
             generator = create_stream_chat_completion(model_type_instance, chat_request)
             return EventSourceResponse(generator, media_type="text/event-stream")
         else:

             response = model_type_instance.invoke(
-                model='gpt-4',
+                model="gpt-4",
                 credentials={
-                    'openai_api_key': "sk-",
-                    'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
-                    'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
+                    "openai_api_key": "sk-",
+                    "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
+                    "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
                 },
-                prompt_messages=[
-                    UserPromptMessage(
-                        content='北京今天的天气怎么样'
-                    )
-                ],
+                prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
                 model_parameters={
-                    'temperature': 0.7,
-                    'top_p': 1.0,
-                    'top_k': 1,
-                    'plugin_web_search': True,
+                    "temperature": 0.7,
+                    "top_p": 1.0,
+                    "top_k": 1,
+                    "plugin_web_search": True,
                 },
-                stop=['you'],
+                stop=["you"],
                 stream=False,
-                user="abc-123"
+                user="abc-123",
             )

             chat_response = ChatCompletionResponse(**dictify(response))
@@ -215,15 +218,19 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):


 def run(
-    cfg: Dict, logging_conf: Optional[dict] = None,
+    cfg: Dict,
+    logging_conf: Optional[dict] = None,
     started_event: mp.Event = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
     try:
         import signal

         # 跳过键盘中断,使用xoscar的信号处理
         signal.signal(signal.SIGINT, lambda *_: None)
-        api = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg=cfg.get("run_openai_api", {}))
+        api = RESTFulOpenAIBootstrapBaseWeb.from_config(
+            cfg=cfg.get("run_openai_api", {})
+        )
         api.set_app_event(started_event=started_event)
         api.serve(logging_conf=logging_conf)

@@ -1,6 +1,6 @@

 from model_providers.core.bootstrap.base import Bootstrap, OpenAIBootstrapBaseWeb
 from model_providers.core.bootstrap.bootstrap_register import bootstrap_register

 __all__ = [
     "bootstrap_register",
     "Bootstrap",
@@ -1,11 +1,13 @@
 from abc import abstractmethod
 from collections import deque

 from fastapi import Request


 class Bootstrap:

     """最大的任务队列"""

     _MAX_ONGOING_TASKS: int = 1

     """任务队列"""
@@ -37,7 +39,6 @@ class Bootstrap:


 class OpenAIBootstrapBaseWeb(Bootstrap):

     def __init__(self):
         super().__init__()

@@ -46,9 +47,13 @@ class OpenAIBootstrapBaseWeb(Bootstrap):
         pass

     @abstractmethod
-    async def create_embeddings(self, request: Request, embeddings_request: EmbeddingsRequest):
+    async def create_embeddings(
+        self, request: Request, embeddings_request: EmbeddingsRequest
+    ):
         pass

     @abstractmethod
-    async def create_chat_completion(self, request: Request, chat_request: ChatCompletionRequest):
+    async def create_chat_completion(
+        self, request: Request, chat_request: ChatCompletionRequest
+    ):
         pass
@@ -5,6 +5,7 @@ class BootstrapRegister:
     """
     注册管理器
     """

     mapping = {
         "bootstrap": {},
     }
@@ -48,4 +49,3 @@ class BootstrapRegister:


 bootstrap_register = BootstrapRegister()

@@ -1,6 +1,7 @@
 import time
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field, root_validator
 from typing_extensions import Literal

@@ -86,13 +87,15 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[list[str]] = None,
+    stop: Optional[list[str]] = (None,)
     stream: Optional[bool] = False

     def to_model_parameters_dict(self, *args, **kwargs):
         # 调用父类的to_dict方法,并排除tools字段
         helper.dump_model
-        return super().dict(exclude={'tools','messages','functions','function_call'}, *args, **kwargs)
+        return super().dict(
+            exclude={"tools", "messages", "functions", "function_call"}, *args, **kwargs
+        )


 class ChatCompletionResponseChoice(BaseModel):
@@ -2,7 +2,7 @@ from enum import Enum


 class PlanningStrategy(Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
+    ROUTER = "router"
+    REACT_ROUTER = "react_router"
+    REACT = "react"
+    FUNCTION_CALL = "function_call"
@@ -5,7 +5,9 @@ from pydantic import BaseModel

 from model_providers.core.entities.provider_configuration import ProviderModelBundle
 from model_providers.core.file.file_obj import FileObj
-from model_providers.core.model_runtime.entities.message_entities import PromptMessageRole
+from model_providers.core.model_runtime.entities.message_entities import (
+    PromptMessageRole,
+)
 from model_providers.core.model_runtime.entities.model_entities import AIModelEntity


@@ -13,6 +15,7 @@ class ModelConfigEntity(BaseModel):
     """
     Model Config Entity.
     """

     provider: str
     model: str
     model_schema: AIModelEntity
@@ -27,6 +30,7 @@ class AdvancedChatMessageEntity(BaseModel):
     """
     Advanced Chat Message Entity.
     """

     text: str
     role: PromptMessageRole

@@ -35,6 +39,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
     """
     Advanced Chat Prompt Template Entity.
     """

     messages: list[AdvancedChatMessageEntity]


@@ -47,6 +52,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
         """
         Role Prefix Entity.
         """

         user: str
         assistant: str

@@ -64,11 +70,12 @@ class PromptTemplateEntity(BaseModel):
         Prompt Type.
         'simple', 'advanced'
         """
-        SIMPLE = 'simple'
-        ADVANCED = 'advanced'
+
+        SIMPLE = "simple"
+        ADVANCED = "advanced"

         @classmethod
-        def value_of(cls, value: str) -> 'PromptType':
+        def value_of(cls, value: str) -> "PromptType":
             """
             Get value of given mode.

@@ -78,18 +85,21 @@ class PromptTemplateEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid prompt type value {value}')
+            raise ValueError(f"invalid prompt type value {value}")

     prompt_type: PromptType
     simple_prompt_template: Optional[str] = None
     advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
-    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+    advanced_completion_prompt_template: Optional[
+        AdvancedCompletionPromptTemplateEntity
+    ] = None


 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
     """

     variable: str
     type: str
     config: dict[str, Any] = {}
@@ -105,11 +115,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
         Dataset Retrieve Strategy.
         'single' or 'multiple'
         """
-        SINGLE = 'single'
-        MULTIPLE = 'multiple'
+
+        SINGLE = "single"
+        MULTIPLE = "multiple"

         @classmethod
-        def value_of(cls, value: str) -> 'RetrieveStrategy':
+        def value_of(cls, value: str) -> "RetrieveStrategy":
             """
             Get value of given mode.

@@ -119,7 +130,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid retrieve strategy value {value}')
+            raise ValueError(f"invalid retrieve strategy value {value}")

     query_variable: Optional[str] = None  # Only when app mode is completion

@@ -134,6 +145,7 @@ class DatasetEntity(BaseModel):
     """
     Dataset Config Entity.
     """

     dataset_ids: list[str]
     retrieve_config: DatasetRetrieveConfigEntity

@@ -142,6 +154,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """

     type: str
     config: dict[str, Any] = {}

@@ -150,6 +163,7 @@ class TextToSpeechEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """

     enabled: bool
     voice: Optional[str] = None
     language: Optional[str] = None
@@ -159,6 +173,7 @@ class FileUploadEntity(BaseModel):
     """
     File Upload Entity.
     """

     image_config: Optional[dict[str, Any]] = None


@@ -166,6 +181,7 @@ class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """

     provider_type: Literal["builtin", "api"]
     provider_id: str
     tool_name: str
@@ -176,6 +192,7 @@ class AgentPromptEntity(BaseModel):
     """
     Agent Prompt Entity.
     """

     first_prompt: str
     next_iteration: str

@@ -189,6 +206,7 @@ class AgentScratchpadUnit(BaseModel):
         """
         Action Entity.
         """

         action_name: str
         action_input: Union[dict, str]

@@ -208,8 +226,9 @@ class AgentEntity(BaseModel):
         """
         Agent Strategy.
         """
-        CHAIN_OF_THOUGHT = 'chain-of-thought'
-        FUNCTION_CALLING = 'function-calling'
+
+        CHAIN_OF_THOUGHT = "chain-of-thought"
+        FUNCTION_CALLING = "function-calling"

     provider: str
     model: str
@@ -223,6 +242,7 @@ class AppOrchestrationConfigEntity(BaseModel):
     """
     App Orchestration Config Entity.
     """

     model_config: ModelConfigEntity
     prompt_template: PromptTemplateEntity
     external_data_variables: list[ExternalDataVariableEntity] = []
@@ -244,13 +264,14 @@ class InvokeFrom(Enum):
     """
     Invoke From.
     """
-    SERVICE_API = 'service-api'
-    WEB_APP = 'web-app'
-    EXPLORE = 'explore'
-    DEBUGGER = 'debugger'
+
+    SERVICE_API = "service-api"
+    WEB_APP = "web-app"
+    EXPLORE = "explore"
+    DEBUGGER = "debugger"

     @classmethod
-    def value_of(cls, value: str) -> 'InvokeFrom':
+    def value_of(cls, value: str) -> "InvokeFrom":
         """
         Get value of given mode.

@@ -260,7 +281,7 @@ class InvokeFrom(Enum):
         for mode in cls:
             if mode.value == value:
                 return mode
-        raise ValueError(f'invalid invoke from value {value}')
+        raise ValueError(f"invalid invoke from value {value}")

     def to_source(self) -> str:
         """
@@ -269,21 +290,22 @@ class InvokeFrom(Enum):
         :return: source
         """
         if self == InvokeFrom.WEB_APP:
-            return 'web_app'
+            return "web_app"
         elif self == InvokeFrom.DEBUGGER:
-            return 'dev'
+            return "dev"
         elif self == InvokeFrom.EXPLORE:
-            return 'explore_app'
+            return "explore_app"
         elif self == InvokeFrom.SERVICE_API:
-            return 'api'
+            return "api"

-        return 'dev'
+        return "dev"


 class ApplicationGenerateEntity(BaseModel):
     """
     Application Generate Entity.
     """

     task_id: str
     tenant_id: str

@@ -1,7 +1,13 @@
 import enum
 from typing import Any, cast

-from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from pydantic import BaseModel

 from model_providers.core.model_runtime.entities.message_entities import (
@@ -16,7 +22,7 @@ from model_providers.core.model_runtime.entities.message_entities import (


 class PromptMessageFileType(enum.Enum):
-    IMAGE = 'image'
+    IMAGE = "image"

     @staticmethod
     def value_of(value):
@@ -33,8 +39,8 @@ class PromptMessageFile(BaseModel):

 class ImagePromptMessageFile(PromptMessageFile):
     class DETAIL(enum.Enum):
-        LOW = 'low'
-        HIGH = 'high'
+        LOW = "low"
+        HIGH = "high"

     type: PromptMessageFileType = PromptMessageFileType.IMAGE
     detail: DETAIL = DETAIL.LOW
@@ -55,32 +61,39 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
                 for file in message.files:
                     if file.type == PromptMessageFileType.IMAGE:
                         file = cast(ImagePromptMessageFile, file)
-                        file_prompt_message_contents.append(ImagePromptMessageContent(
-                            data=file.data,
-                            detail=ImagePromptMessageContent.DETAIL.HIGH
-                            if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
-                        ))
+                        file_prompt_message_contents.append(
+                            ImagePromptMessageContent(
+                                data=file.data,
+                                detail=ImagePromptMessageContent.DETAIL.HIGH
+                                if file.detail.value == "high"
+                                else ImagePromptMessageContent.DETAIL.LOW,
+                            )
+                        )

-                prompt_message_contents = [TextPromptMessageContent(data=message.content)]
+                prompt_message_contents = [
+                    TextPromptMessageContent(data=message.content)
+                ]
                 prompt_message_contents.extend(file_prompt_message_contents)

-                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+                prompt_messages.append(
+                    UserPromptMessage(content=prompt_message_contents)
+                )
             else:
                 prompt_messages.append(UserPromptMessage(content=message.content))
         elif isinstance(message, AIMessage):
-            message_kwargs = {
-                'content': message.content
-            }
+            message_kwargs = {"content": message.content}

-            if 'function_call' in message.additional_kwargs:
-                message_kwargs['tool_calls'] = [
+            if "function_call" in message.additional_kwargs:
+                message_kwargs["tool_calls"] = [
                     AssistantPromptMessage.ToolCall(
-                        id=message.additional_kwargs['function_call']['id'],
-                        type='function',
+                        id=message.additional_kwargs["function_call"]["id"],
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=message.additional_kwargs['function_call']['name'],
-                            arguments=message.additional_kwargs['function_call']['arguments']
-                        )
+                            name=message.additional_kwargs["function_call"]["name"],
+                            arguments=message.additional_kwargs["function_call"][
+                                "arguments"
+                            ],
+                        ),
                     )
                 ]

@@ -88,12 +101,16 @@ def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMe
         elif isinstance(message, SystemMessage):
             prompt_messages.append(SystemPromptMessage(content=message.content))
         elif isinstance(message, FunctionMessage):
-            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
+            prompt_messages.append(
+                ToolPromptMessage(content=message.content, tool_call_id=message.name)
+            )

     return prompt_messages


-def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
+def prompt_messages_to_lc_messages(
+    prompt_messages: list[PromptMessage],
+) -> list[BaseMessage]:
     messages = []
     for prompt_message in prompt_messages:
         if isinstance(prompt_message, UserPromptMessage):
@@ -105,24 +122,24 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list
                 if isinstance(content, TextPromptMessageContent):
                     message_contents.append(content.data)
                 elif isinstance(content, ImagePromptMessageContent):
-                    message_contents.append({
-                        'type': 'image',
-                        'data': content.data,
-                        'detail': content.detail.value
-                    })
+                    message_contents.append(
+                        {
+                            "type": "image",
+                            "data": content.data,
+                            "detail": content.detail.value,
+                        }
+                    )

             messages.append(HumanMessage(content=message_contents))
         elif isinstance(prompt_message, AssistantPromptMessage):
-            message_kwargs = {
-                'content': prompt_message.content
-            }
+            message_kwargs = {"content": prompt_message.content}

             if prompt_message.tool_calls:
-                message_kwargs['additional_kwargs'] = {
-                    'function_call': {
-                        'id': prompt_message.tool_calls[0].id,
-                        'name': prompt_message.tool_calls[0].function.name,
-                        'arguments': prompt_message.tool_calls[0].function.arguments
+                message_kwargs["additional_kwargs"] = {
+                    "function_call": {
+                        "id": prompt_message.tool_calls[0].id,
+                        "name": prompt_message.tool_calls[0].function.name,
+                        "arguments": prompt_message.tool_calls[0].function.arguments,
                     }
                 }

@@ -130,6 +147,10 @@ def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list
         elif isinstance(prompt_message, SystemPromptMessage):
             messages.append(SystemMessage(content=prompt_message.content))
         elif isinstance(prompt_message, ToolPromptMessage):
-            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
+            messages.append(
+                FunctionMessage(
+                    name=prompt_message.tool_call_id, content=prompt_message.content
+                )
+            )

     return messages
@@ -4,7 +4,10 @@ from typing import Optional
 from pydantic import BaseModel

 from model_providers.core.model_runtime.entities.common_entities import I18nObject
-from model_providers.core.model_runtime.entities.model_entities import ModelType, ProviderModel
+from model_providers.core.model_runtime.entities.model_entities import (
+    ModelType,
+    ProviderModel,
+)
 from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity


@@ -12,6 +15,7 @@ class ModelStatus(Enum):
     """
     Enum class for model status.
     """

     ACTIVE = "active"
     NO_CONFIGURE = "no-configure"
     QUOTA_EXCEEDED = "quota-exceeded"
@@ -22,6 +26,7 @@ class SimpleModelProviderEntity(BaseModel):
     """
     Simple provider.
     """

     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -39,7 +44,7 @@ class SimpleModelProviderEntity(BaseModel):
             label=provider_entity.label,
             icon_small=provider_entity.icon_small,
             icon_large=provider_entity.icon_large,
-            supported_model_types=provider_entity.supported_model_types
+            supported_model_types=provider_entity.supported_model_types,
         )


@@ -47,6 +52,7 @@ class ModelWithProviderEntity(ProviderModel):
     """
     Model with provider entity.
     """

     provider: SimpleModelProviderEntity
     status: ModelStatus

@@ -55,6 +61,7 @@ class DefaultModelProviderEntity(BaseModel):
     """
     Default model provider entity.
     """

     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -66,6 +73,7 @@ class DefaultModelEntity(BaseModel):
     """
     Default model entity.
     """

     model: str
     model_type: ModelType
     provider: DefaultModelProviderEntity
@@ -7,9 +7,16 @@ from typing import Optional

 from pydantic import BaseModel

-from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
+from model_providers.core.entities.model_entities import (
+    ModelStatus,
+    ModelWithProviderEntity,
+    SimpleModelProviderEntity,
+)
 from model_providers.core.entities.provider_entities import CustomConfiguration
-from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
+from model_providers.core.model_runtime.entities.model_entities import (
+    FetchFrom,
+    ModelType,
+)
 from model_providers.core.model_runtime.entities.provider_entities import (
     ConfigurateMethod,
     CredentialFormSchema,
@@ -18,7 +25,9 @@ from model_providers.core.model_runtime.entities.provider_entities import (
 )
 from model_providers.core.model_runtime.model_providers import model_provider_factory
 from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
-from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from model_providers.core.model_runtime.model_providers.__base.model_provider import (
+    ModelProvider,
+)

 logger = logging.getLogger(__name__)

@@ -27,13 +36,16 @@ class ProviderConfiguration(BaseModel):
     """
     Model class for provider configuration.
     """

     provider: ProviderEntity
     custom_configuration: CustomConfiguration

     def __init__(self, **data):
         super().__init__(**data)

-    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
+    def get_current_credentials(
+        self, model_type: ModelType, model: str
+    ) -> Optional[dict]:
         """
         Get current credentials.

@@ -43,7 +55,10 @@ class ProviderConfiguration(BaseModel):
         """
         if self.custom_configuration.models:
             for model_configuration in self.custom_configuration.models:
-                if model_configuration.model_type == model_type and model_configuration.model == model:
+                if (
+                    model_configuration.model_type == model_type
+                    and model_configuration.model == model
+                ):
                     return model_configuration.credentials

         if self.custom_configuration.provider:
@@ -69,8 +84,9 @@ class ProviderConfiguration(BaseModel):
             copy_credentials = credentials.copy()
             return copy_credentials

-    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-            -> Optional[dict]:
+    def get_custom_model_credentials(
+        self, model_type: ModelType, model: str, obfuscated: bool = False
+    ) -> Optional[dict]:
         """
         Get custom model credentials.

@@ -83,7 +99,10 @@ class ProviderConfiguration(BaseModel):
             return None

         for model_configuration in self.custom_configuration.models:
-            if model_configuration.model_type == model_type and model_configuration.model == model:
+            if (
+                model_configuration.model_type == model_type
+                and model_configuration.model == model
+            ):
                 credentials = model_configuration.credentials
                 if not obfuscated:
                     return credentials
@@ -113,9 +132,9 @@ class ProviderConfiguration(BaseModel):
         # Get model instance of LLM
         return provider_instance.get_model_instance(model_type)

-    def get_provider_model(self, model_type: ModelType,
-                           model: str,
-                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
+    def get_provider_model(
+        self, model_type: ModelType, model: str, only_active: bool = False
+    ) -> Optional[ModelWithProviderEntity]:
         """
         Get provider model.
         :param model_type: model type
@@ -131,8 +150,9 @@ class ProviderConfiguration(BaseModel):

         return None

-    def get_provider_models(self, model_type: Optional[ModelType] = None,
-                            only_active: bool = False) -> list[ModelWithProviderEntity]:
+    def get_provider_models(
+        self, model_type: Optional[ModelType] = None, only_active: bool = False
+    ) -> list[ModelWithProviderEntity]:
         """
         Get provider models.
         :param model_type: model type
@@ -148,18 +168,19 @@ class ProviderConfiguration(BaseModel):
         model_types = provider_instance.get_provider_schema().supported_model_types

         provider_models = self._get_custom_provider_models(
-            model_types=model_types,
-            provider_instance=provider_instance
+            model_types=model_types, provider_instance=provider_instance
         )
         if only_active:
-            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
+            provider_models = [
+                m for m in provider_models if m.status == ModelStatus.ACTIVE
+            ]

         # resort provider_models
         return sorted(provider_models, key=lambda x: x.model_type.value)

-    def _get_custom_provider_models(self,
-                                    model_types: list[ModelType],
-                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+    def _get_custom_provider_models(
+        self, model_types: list[ModelType], provider_instance: ModelProvider
+    ) -> list[ModelWithProviderEntity]:
         """
         Get custom provider models.

@@ -189,7 +210,9 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                        status=ModelStatus.ACTIVE
+                        if credentials
+                        else ModelStatus.NO_CONFIGURE,
                     )
                 )

@@ -199,15 +222,13 @@ class ProviderConfiguration(BaseModel):
                 continue

             try:
-                custom_model_schema = (
-                    provider_instance.get_model_instance(model_configuration.model_type)
-                    .get_customizable_model_schema_from_credentials(
-                        model_configuration.model,
-                        model_configuration.credentials
-                    )
+                custom_model_schema = provider_instance.get_model_instance(
+                    model_configuration.model_type
+                ).get_customizable_model_schema_from_credentials(
+                    model_configuration.model, model_configuration.credentials
                 )
             except Exception as ex:
-                logger.warning(f'get custom model schema failed, {ex}')
+                logger.warning(f"get custom model schema failed, {ex}")
                 continue

             if not custom_model_schema:
@@ -223,7 +244,7 @@ class ProviderConfiguration(BaseModel):
                     model_properties=custom_model_schema.model_properties,
                     deprecated=custom_model_schema.deprecated,
                     provider=SimpleModelProviderEntity(self.provider),
-                    status=ModelStatus.ACTIVE
+                    status=ModelStatus.ACTIVE,
                 )
             )

@@ -234,16 +255,18 @@ class ProviderConfigurations(BaseModel):
     """
     Model class for provider configuration dict.
     """

     configurations: dict[str, ProviderConfiguration] = {}

     def __init__(self):
         super().__init__()

-    def get_models(self,
-                   provider: Optional[str] = None,
-                   model_type: Optional[ModelType] = None,
-                   only_active: bool = False) \
-            -> list[ModelWithProviderEntity]:
+    def get_models(
+        self,
+        provider: Optional[str] = None,
+        model_type: Optional[ModelType] = None,
+        only_active: bool = False,
+    ) -> list[ModelWithProviderEntity]:
         """
         Get available models.

@@ -278,7 +301,9 @@ class ProviderConfigurations(BaseModel):
             if provider and provider_configuration.provider.provider != provider:
                 continue

-            all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
+            all_models.extend(
+                provider_configuration.get_provider_models(model_type, only_active)
+            )

         return all_models

@@ -310,6 +335,7 @@ class ProviderModelBundle(BaseModel):
     """
     Provider model bundle.
     """

     configuration: ProviderConfiguration
     provider_instance: ModelProvider
     model_type_instance: AIModel
@@ -12,11 +12,11 @@ class RestrictModel(BaseModel):
     model_type: ModelType




 class CustomProviderConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """

     credentials: dict


@@ -24,6 +24,7 @@ class CustomModelConfiguration(BaseModel):
     """
     Model class for provider custom model configuration.
     """

     model: str
     model_type: ModelType
     credentials: dict
@@ -33,5 +34,6 @@ class CustomConfiguration(BaseModel):
     """
     Model class for provider custom configuration.
     """

     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []
@@ -3,13 +3,17 @@ from typing import Any

 from pydantic import BaseModel

-from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMResult,
+    LLMResultChunk,
+)


 class QueueEvent(Enum):
     """
     QueueEvent enum
     """

     MESSAGE = "message"
     AGENT_MESSAGE = "agent_message"
     MESSAGE_REPLACE = "message-replace"
@@ -27,6 +31,7 @@ class AppQueueEvent(BaseModel):
     """
     QueueEvent entity
     """

     event: QueueEvent


@@ -34,21 +39,25 @@ class QueueMessageEvent(AppQueueEvent):
     """
     QueueMessageEvent entity
     """

     event = QueueEvent.MESSAGE
     chunk: LLMResultChunk


 class QueueAgentMessageEvent(AppQueueEvent):
     """
     QueueMessageEvent entity
     """

     event = QueueEvent.AGENT_MESSAGE
     chunk: LLMResultChunk


 class QueueMessageReplaceEvent(AppQueueEvent):
     """
     QueueMessageReplaceEvent entity
     """

     event = QueueEvent.MESSAGE_REPLACE
     text: str

@@ -57,6 +66,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
     """
     QueueRetrieverResourcesEvent entity
     """

     event = QueueEvent.RETRIEVER_RESOURCES
     retriever_resources: list[dict]

@@ -65,6 +75,7 @@ class AnnotationReplyEvent(AppQueueEvent):
     """
     AnnotationReplyEvent entity
     """

     event = QueueEvent.ANNOTATION_REPLY
     message_annotation_id: str

@@ -73,28 +84,34 @@ class QueueMessageEndEvent(AppQueueEvent):
     """
     QueueMessageEndEvent entity
     """

     event = QueueEvent.MESSAGE_END
     llm_result: LLMResult


 class QueueAgentThoughtEvent(AppQueueEvent):
     """
     QueueAgentThoughtEvent entity
     """

     event = QueueEvent.AGENT_THOUGHT
     agent_thought_id: str


 class QueueMessageFileEvent(AppQueueEvent):
     """
     QueueAgentThoughtEvent entity
     """

     event = QueueEvent.MESSAGE_FILE
     message_file_id: str


 class QueueErrorEvent(AppQueueEvent):
     """
     QueueErrorEvent entity
     """

     event = QueueEvent.ERROR
     error: Any

@@ -103,6 +120,7 @@ class QueuePingEvent(AppQueueEvent):
     """
     QueuePingEvent entity
     """

     event = QueueEvent.PING


@@ -110,10 +128,12 @@ class QueueStopEvent(AppQueueEvent):
     """
     QueueStopEvent entity
     """

     class StopBy(Enum):
         """
         Stop by enum
         """

         USER_MANUAL = "user-manual"
         ANNOTATION_REPLY = "annotation-reply"
         OUTPUT_MODERATION = "output-moderation"
@@ -126,6 +146,7 @@ class QueueMessage(BaseModel):
     """
     QueueMessage entity
     """

     task_id: str
     message_id: str
     conversation_id: str
|||||||
@ -2,23 +2,40 @@ from collections.abc import Generator
|
|||||||
from typing import IO, Optional, Union, cast
|
from typing import IO, Optional, Union, cast
|
||||||
|
|
||||||
from model_providers.core.entities.provider_configuration import ProviderModelBundle
|
from model_providers.core.entities.provider_configuration import ProviderModelBundle
|
||||||
from model_providers.errors.error import ProviderTokenNotInitError
|
|
||||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
|
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
from model_providers.core.model_runtime.entities.text_embedding_entities import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
TextEmbeddingResult,
|
||||||
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
LargeLanguageModel,
|
||||||
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.moderation_model import (
|
||||||
|
ModerationModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
|
||||||
|
RerankModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import (
|
||||||
|
Speech2TextModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||||
|
TextEmbeddingModel,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
from model_providers.core.provider_manager import ProviderManager
|
from model_providers.core.provider_manager import ProviderManager
|
||||||
|
from model_providers.errors.error import ProviderTokenNotInitError
|
||||||
|
|
||||||
|
|
||||||
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
|
def _fetch_credentials_from_bundle(
|
||||||
|
provider_model_bundle: ProviderModelBundle, model: str
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Fetch credentials from provider model bundle
|
Fetch credentials from provider model bundle
|
||||||
:param provider_model_bundle: provider model bundle
|
:param provider_model_bundle: provider model bundle
|
||||||
@ -26,12 +43,13 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
credentials = provider_model_bundle.configuration.get_current_credentials(
|
credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||||
model_type=provider_model_bundle.model_type_instance.model_type,
|
model_type=provider_model_bundle.model_type_instance.model_type, model=model
|
||||||
model=model
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if credentials is None:
|
if credentials is None:
|
||||||
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
|
raise ProviderTokenNotInitError(
|
||||||
|
f"Model {model} credentials is not initialized."
|
||||||
|
)
|
||||||
|
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
@ -48,10 +66,16 @@ class ModelInstance:
|
|||||||
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
|
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||||
self.model_type_instance = self._provider_model_bundle.model_type_instance
|
self.model_type_instance = self._provider_model_bundle.model_type_instance
|
||||||
|
|
||||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
def invoke_llm(
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
self,
|
||||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
prompt_messages: list[PromptMessage],
|
||||||
-> Union[LLMResult, Generator]:
|
model_parameters: Optional[dict] = None,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: list[Callback] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -77,11 +101,12 @@ class ModelInstance:
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
user=user,
|
user=user,
|
||||||
callbacks=callbacks
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
|
def invoke_text_embedding(
|
||||||
-> TextEmbeddingResult:
|
self, texts: list[str], user: Optional[str] = None
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -94,16 +119,17 @@ class ModelInstance:
|
|||||||
|
|
||||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||||
return self.model_type_instance.invoke(
|
return self.model_type_instance.invoke(
|
||||||
model=self.model,
|
model=self.model, credentials=self.credentials, texts=texts, user=user
|
||||||
credentials=self.credentials,
|
|
||||||
texts=texts,
|
|
||||||
user=user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
|
def invoke_rerank(
|
||||||
top_n: Optional[int] = None,
|
self,
|
||||||
user: Optional[str] = None) \
|
query: str,
|
||||||
-> RerankResult:
|
docs: list[str],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> RerankResult:
|
||||||
"""
|
"""
|
||||||
Invoke rerank model
|
Invoke rerank model
|
||||||
|
|
||||||
@ -125,11 +151,10 @@ class ModelInstance:
|
|||||||
docs=docs,
|
docs=docs,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
user=user
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_moderation(self, text: str, user: Optional[str] = None) \
|
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
|
||||||
-> bool:
|
|
||||||
"""
|
"""
|
||||||
Invoke moderation model
|
Invoke moderation model
|
||||||
|
|
||||||
@ -142,14 +167,10 @@ class ModelInstance:
|
|||||||
|
|
||||||
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
||||||
return self.model_type_instance.invoke(
|
return self.model_type_instance.invoke(
|
||||||
model=self.model,
|
model=self.model, credentials=self.credentials, text=text, user=user
|
||||||
credentials=self.credentials,
|
|
||||||
text=text,
|
|
||||||
user=user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||||
-> str:
|
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -162,14 +183,17 @@ class ModelInstance:
|
|||||||
|
|
||||||
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
||||||
return self.model_type_instance.invoke(
|
return self.model_type_instance.invoke(
|
||||||
model=self.model,
|
model=self.model, credentials=self.credentials, file=file, user=user
|
||||||
credentials=self.credentials,
|
|
||||||
file=file,
|
|
||||||
user=user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
|
def invoke_tts(
|
||||||
-> str:
|
self,
|
||||||
|
content_text: str,
|
||||||
|
tenant_id: str,
|
||||||
|
voice: str,
|
||||||
|
streaming: bool,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Invoke large language tts model
|
Invoke large language tts model
|
||||||
|
|
||||||
@ -191,7 +215,7 @@ class ModelInstance:
|
|||||||
user=user,
|
user=user,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
voice=voice,
|
voice=voice,
|
||||||
streaming=streaming
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tts_voices(self, language: str) -> list:
|
def get_tts_voices(self, language: str) -> list:
|
||||||
@ -206,21 +230,24 @@ class ModelInstance:
|
|||||||
|
|
||||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||||
return self.model_type_instance.get_tts_model_voices(
|
return self.model_type_instance.get_tts_model_voices(
|
||||||
model=self.model,
|
model=self.model, credentials=self.credentials, language=language
|
||||||
credentials=self.credentials,
|
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
provider_name_to_provider_records_dict: dict,
|
self,
|
||||||
provider_name_to_provider_model_records_dict: dict) -> None:
|
provider_name_to_provider_records_dict: dict,
|
||||||
|
provider_name_to_provider_model_records_dict: dict,
|
||||||
|
) -> None:
|
||||||
self._provider_manager = ProviderManager(
|
self._provider_manager = ProviderManager(
|
||||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict)
|
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
||||||
|
)
|
||||||
|
|
||||||
def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
def get_model_instance(
|
||||||
|
self, provider: str, model_type: ModelType, model: str
|
||||||
|
) -> ModelInstance:
|
||||||
"""
|
"""
|
||||||
Get model instance
|
Get model instance
|
||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
@ -231,8 +258,7 @@ class ModelManager:
|
|||||||
if not provider:
|
if not provider:
|
||||||
return self.get_default_model_instance(model_type)
|
return self.get_default_model_instance(model_type)
|
||||||
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
|
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
|
||||||
provider=provider,
|
provider=provider, model_type=model_type
|
||||||
model_type=model_type
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ModelInstance(provider_model_bundle, model)
|
return ModelInstance(provider_model_bundle, model)
|
||||||
@ -253,5 +279,5 @@ class ModelManager:
|
|||||||
return self.get_model_instance(
|
return self.get_model_instance(
|
||||||
provider=default_model_entity.provider.provider,
|
provider=default_model_entity.provider.provider,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
model=default_model_entity.model
|
model=default_model_entity.model,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,8 +1,14 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
|
|
||||||
_TEXT_COLOR_MAPPING = {
|
_TEXT_COLOR_MAPPING = {
|
||||||
@ -19,12 +25,21 @@ class Callback(ABC):
|
|||||||
Base class for callbacks.
|
Base class for callbacks.
|
||||||
Only for LLM.
|
Only for LLM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise_error: bool = False
|
raise_error: bool = False
|
||||||
|
|
||||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
def on_before_invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Before invoke callback
|
Before invoke callback
|
||||||
|
|
||||||
@ -40,10 +55,19 @@ class Callback(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
def on_new_chunk(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None):
|
chunk: LLMResultChunk,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
On new chunk callback
|
On new chunk callback
|
||||||
|
|
||||||
@ -60,10 +84,19 @@ class Callback(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
def on_after_invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
result: LLMResult,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
After invoke callback
|
After invoke callback
|
||||||
|
|
||||||
@ -80,10 +113,19 @@ class Callback(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
def on_invoke_error(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
ex: Exception,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Invoke error callback
|
Invoke error callback
|
||||||
|
|
||||||
@ -100,9 +142,7 @@ class Callback(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def print_text(
|
def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||||
self, text: str, color: Optional[str] = None, end: str = ""
|
|
||||||
) -> None:
|
|
||||||
"""Print text with highlighting and no end characters."""
|
"""Print text with highlighting and no end characters."""
|
||||||
text_to_print = self._get_colored_text(text, color) if color else text
|
text_to_print = self._get_colored_text(text, color) if color else text
|
||||||
print(text_to_print, end=end)
|
print(text_to_print, end=end)
|
||||||
|
|||||||
@ -4,17 +4,32 @@ import sys
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
from model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoggingCallback(Callback):
|
class LoggingCallback(Callback):
|
||||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
def on_before_invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Before invoke callback
|
Before invoke callback
|
||||||
|
|
||||||
@ -28,40 +43,49 @@ class LoggingCallback(Callback):
|
|||||||
:param stream: is stream response
|
:param stream: is stream response
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_llm_before_invoke]\n", color='blue')
|
self.print_text("\n[on_llm_before_invoke]\n", color="blue")
|
||||||
self.print_text(f"Model: {model}\n", color='blue')
|
self.print_text(f"Model: {model}\n", color="blue")
|
||||||
self.print_text("Parameters:\n", color='blue')
|
self.print_text("Parameters:\n", color="blue")
|
||||||
for key, value in model_parameters.items():
|
for key, value in model_parameters.items():
|
||||||
self.print_text(f"\t{key}: {value}\n", color='blue')
|
self.print_text(f"\t{key}: {value}\n", color="blue")
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
self.print_text(f"\tstop: {stop}\n", color='blue')
|
self.print_text(f"\tstop: {stop}\n", color="blue")
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
self.print_text("\tTools:\n", color='blue')
|
self.print_text("\tTools:\n", color="blue")
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
self.print_text(f"\t\t{tool.name}\n", color='blue')
|
self.print_text(f"\t\t{tool.name}\n", color="blue")
|
||||||
|
|
||||||
self.print_text(f"Stream: {stream}\n", color='blue')
|
self.print_text(f"Stream: {stream}\n", color="blue")
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
self.print_text(f"User: {user}\n", color='blue')
|
self.print_text(f"User: {user}\n", color="blue")
|
||||||
|
|
||||||
self.print_text("Prompt messages:\n", color='blue')
|
self.print_text("Prompt messages:\n", color="blue")
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if prompt_message.name:
|
if prompt_message.name:
|
||||||
self.print_text(f"\tname: {prompt_message.name}\n", color='blue')
|
self.print_text(f"\tname: {prompt_message.name}\n", color="blue")
|
||||||
|
|
||||||
self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue')
|
self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue")
|
||||||
self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue')
|
self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
self.print_text("\n[on_llm_new_chunk]")
|
self.print_text("\n[on_llm_new_chunk]")
|
||||||
|
|
||||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
def on_new_chunk(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None):
|
chunk: LLMResultChunk,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
On new chunk callback
|
On new chunk callback
|
||||||
|
|
||||||
@ -79,10 +103,19 @@ class LoggingCallback(Callback):
|
|||||||
sys.stdout.write(chunk.delta.message.content)
|
sys.stdout.write(chunk.delta.message.content)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
def on_after_invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
result: LLMResult,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
After invoke callback
|
After invoke callback
|
||||||
|
|
||||||
@ -97,24 +130,37 @@ class LoggingCallback(Callback):
|
|||||||
:param stream: is stream response
|
:param stream: is stream response
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_llm_after_invoke]\n", color='yellow')
|
self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
|
||||||
self.print_text(f"Content: {result.message.content}\n", color='yellow')
|
self.print_text(f"Content: {result.message.content}\n", color="yellow")
|
||||||
|
|
||||||
if result.message.tool_calls:
|
if result.message.tool_calls:
|
||||||
self.print_text("Tool calls:\n", color='yellow')
|
self.print_text("Tool calls:\n", color="yellow")
|
||||||
for tool_call in result.message.tool_calls:
|
for tool_call in result.message.tool_calls:
|
||||||
self.print_text(f"\t{tool_call.id}\n", color='yellow')
|
self.print_text(f"\t{tool_call.id}\n", color="yellow")
|
||||||
self.print_text(f"\t{tool_call.function.name}\n", color='yellow')
|
self.print_text(f"\t{tool_call.function.name}\n", color="yellow")
|
||||||
self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow')
|
self.print_text(
|
||||||
|
f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow"
|
||||||
|
)
|
||||||
|
|
||||||
self.print_text(f"Model: {result.model}\n", color='yellow')
|
self.print_text(f"Model: {result.model}\n", color="yellow")
|
||||||
self.print_text(f"Usage: {result.usage}\n", color='yellow')
|
self.print_text(f"Usage: {result.usage}\n", color="yellow")
|
||||||
self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow')
|
self.print_text(
|
||||||
|
f"System Fingerprint: {result.system_fingerprint}\n", color="yellow"
|
||||||
|
)
|
||||||
|
|
||||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
def on_invoke_error(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
llm_instance: AIModel,
|
||||||
stream: bool = True, user: Optional[str] = None) -> None:
|
ex: Exception,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Invoke error callback
|
Invoke error callback
|
||||||
|
|
||||||
@ -129,5 +175,5 @@ class LoggingCallback(Callback):
|
|||||||
:param stream: is stream response
|
:param stream: is stream response
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
"""
|
"""
|
||||||
self.print_text("\n[on_llm_invoke_error]\n", color='red')
|
self.print_text("\n[on_llm_invoke_error]\n", color="red")
|
||||||
logger.exception(ex)
|
logger.exception(ex)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class I18nObject(BaseModel):
     """
     Model class for i18n object.
     """
+
     zh_Hans: Optional[str] = None
     en_US: str
 

@ -1,98 +1,99 @@
|
|||||||
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
from model_providers.core.model_runtime.entities.model_entities import DefaultParameterName
|
DefaultParameterName,
|
||||||
|
)
|
||||||
|
|
||||||
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
||||||
DefaultParameterName.TEMPERATURE: {
|
DefaultParameterName.TEMPERATURE: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Temperature',
|
"en_US": "Temperature",
|
||||||
'zh_Hans': '温度',
|
"zh_Hans": "温度",
|
||||||
},
|
},
|
||||||
'type': 'float',
|
"type": "float",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.',
|
"en_US": "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
|
||||||
'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。',
|
"zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': 0.0,
|
"default": 0.0,
|
||||||
'min': 0.0,
|
"min": 0.0,
|
||||||
'max': 1.0,
|
"max": 1.0,
|
||||||
'precision': 2,
|
"precision": 2,
|
||||||
},
|
},
|
||||||
DefaultParameterName.TOP_P: {
|
DefaultParameterName.TOP_P: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Top P',
|
"en_US": "Top P",
|
||||||
'zh_Hans': 'Top P',
|
"zh_Hans": "Top P",
|
||||||
},
|
},
|
||||||
'type': 'float',
|
"type": "float",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.',
|
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
|
||||||
'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。',
|
"zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': 1.0,
|
"default": 1.0,
|
||||||
'min': 0.0,
|
"min": 0.0,
|
||||||
'max': 1.0,
|
"max": 1.0,
|
||||||
'precision': 2,
|
"precision": 2,
|
||||||
},
|
},
|
||||||
DefaultParameterName.PRESENCE_PENALTY: {
|
DefaultParameterName.PRESENCE_PENALTY: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Presence Penalty',
|
"en_US": "Presence Penalty",
|
||||||
'zh_Hans': '存在惩罚',
|
"zh_Hans": "存在惩罚",
|
||||||
},
|
},
|
||||||
'type': 'float',
|
"type": "float",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'Applies a penalty to the log-probability of tokens already in the text.',
|
"en_US": "Applies a penalty to the log-probability of tokens already in the text.",
|
||||||
'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。',
|
"zh_Hans": "对文本中已有的标记的对数概率施加惩罚。",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': 0.0,
|
"default": 0.0,
|
||||||
'min': 0.0,
|
"min": 0.0,
|
||||||
'max': 1.0,
|
"max": 1.0,
|
||||||
'precision': 2,
|
"precision": 2,
|
||||||
},
|
},
|
||||||
DefaultParameterName.FREQUENCY_PENALTY: {
|
DefaultParameterName.FREQUENCY_PENALTY: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Frequency Penalty',
|
"en_US": "Frequency Penalty",
|
||||||
'zh_Hans': '频率惩罚',
|
"zh_Hans": "频率惩罚",
|
||||||
},
|
},
|
||||||
'type': 'float',
|
"type": "float",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'Applies a penalty to the log-probability of tokens that appear in the text.',
|
"en_US": "Applies a penalty to the log-probability of tokens that appear in the text.",
|
||||||
'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。',
|
"zh_Hans": "对文本中出现的标记的对数概率施加惩罚。",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': 0.0,
|
"default": 0.0,
|
||||||
'min': 0.0,
|
"min": 0.0,
|
||||||
'max': 1.0,
|
"max": 1.0,
|
||||||
'precision': 2,
|
"precision": 2,
|
||||||
},
|
},
|
||||||
DefaultParameterName.MAX_TOKENS: {
|
DefaultParameterName.MAX_TOKENS: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Max Tokens',
|
"en_US": "Max Tokens",
|
||||||
'zh_Hans': '最大标记',
|
"zh_Hans": "最大标记",
|
||||||
},
|
},
|
||||||
'type': 'int',
|
"type": "int",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.',
|
"en_US": "The maximum number of tokens to generate. Requests can use up to 2048 tokens shared between prompt and completion.",
|
||||||
'zh_Hans': '要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。',
|
"zh_Hans": "要生成的标记的最大数量。请求可以使用最多2048个标记,这些标记在提示和完成之间共享。",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': 64,
|
"default": 64,
|
||||||
'min': 1,
|
"min": 1,
|
||||||
'max': 2048,
|
"max": 2048,
|
||||||
'precision': 0,
|
"precision": 0,
|
||||||
},
|
},
|
||||||
DefaultParameterName.RESPONSE_FORMAT: {
|
DefaultParameterName.RESPONSE_FORMAT: {
|
||||||
'label': {
|
"label": {
|
||||||
'en_US': 'Response Format',
|
"en_US": "Response Format",
|
||||||
'zh_Hans': '回复格式',
|
"zh_Hans": "回复格式",
|
||||||
},
|
},
|
||||||
'type': 'string',
|
"type": "string",
|
||||||
'help': {
|
"help": {
|
||||||
'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
|
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
|
||||||
'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
|
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
|
||||||
},
|
},
|
||||||
'required': False,
|
"required": False,
|
||||||
'options': ['JSON', 'XML'],
|
"options": ["JSON", "XML"],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,19 +4,26 @@ from typing import Optional
 
 from pydantic import BaseModel
 
-from model_providers.core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
-from model_providers.core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
+from model_providers.core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+)
+from model_providers.core.model_runtime.entities.model_entities import (
+    ModelUsage,
+    PriceInfo,
+)
 
 
 class LLMMode(Enum):
     """
     Enum class for large language model mode.
     """
+
     COMPLETION = "completion"
     CHAT = "chat"
 
     @classmethod
-    def value_of(cls, value: str) -> 'LLMMode':
+    def value_of(cls, value: str) -> "LLMMode":
         """
         Get value of given mode.
 
@@ -26,13 +33,14 @@ class LLMMode(Enum):
         for mode in cls:
             if mode.value == value:
                 return mode
-        raise ValueError(f'invalid mode value {value}')
+        raise ValueError(f"invalid mode value {value}")
 
 
 class LLMUsage(ModelUsage):
     """
     Model class for llm usage.
     """
+
     prompt_tokens: int
     prompt_unit_price: Decimal
     prompt_price_unit: Decimal
@@ -50,17 +58,17 @@ class LLMUsage(ModelUsage):
     def empty_usage(cls):
         return cls(
             prompt_tokens=0,
-            prompt_unit_price=Decimal('0.0'),
-            prompt_price_unit=Decimal('0.0'),
-            prompt_price=Decimal('0.0'),
+            prompt_unit_price=Decimal("0.0"),
+            prompt_price_unit=Decimal("0.0"),
+            prompt_price=Decimal("0.0"),
             completion_tokens=0,
-            completion_unit_price=Decimal('0.0'),
-            completion_price_unit=Decimal('0.0'),
-            completion_price=Decimal('0.0'),
+            completion_unit_price=Decimal("0.0"),
+            completion_price_unit=Decimal("0.0"),
+            completion_price=Decimal("0.0"),
             total_tokens=0,
-            total_price=Decimal('0.0'),
-            currency='USD',
-            latency=0.0
+            total_price=Decimal("0.0"),
+            currency="USD",
+            latency=0.0,
         )
 
 
@@ -68,6 +76,7 @@ class LLMResult(BaseModel):
     """
     Model class for llm result.
     """
+
     model: str
     prompt_messages: list[PromptMessage]
     message: AssistantPromptMessage
@@ -79,6 +88,7 @@ class LLMResultChunkDelta(BaseModel):
     """
     Model class for llm result chunk delta.
     """
+
     index: int
     message: AssistantPromptMessage
     usage: Optional[LLMUsage] = None
@@ -89,6 +99,7 @@ class LLMResultChunk(BaseModel):
     """
     Model class for llm result chunk.
     """
+
     model: str
     prompt_messages: list[PromptMessage]
     system_fingerprint: Optional[str] = None
@@ -99,4 +110,5 @@ class NumTokensResult(PriceInfo):
     """
     Model class for number of tokens result.
     """
+
     tokens: int

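The hunks above are formatting-only, but they do expose the llm_entities field layout. The following is an illustrative sketch, not part of the commit: it assembles one streaming chunk by hand, and any defaults the hunks do not show (for example an empty tool_calls list on AssistantPromptMessage) are assumptions.

```python
# Illustrative only: field names come from the hunks above; defaults and any
# fields not shown there are assumptions.
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)

chunk = LLMResultChunk(
    model="gpt-4",
    prompt_messages=[UserPromptMessage(content="hello")],
    delta=LLMResultChunkDelta(
        index=0,
        message=AssistantPromptMessage(content="hi there"),
        usage=LLMUsage.empty_usage(),
    ),
)
print(chunk.delta.message.content)  # -> "hi there"
```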
@@ -9,13 +9,14 @@ class PromptMessageRole(Enum):
     """
     Enum class for prompt message.
     """
+
     SYSTEM = "system"
     USER = "user"
     ASSISTANT = "assistant"
     TOOL = "tool"
 
     @classmethod
-    def value_of(cls, value: str) -> 'PromptMessageRole':
+    def value_of(cls, value: str) -> "PromptMessageRole":
         """
         Get value of given mode.
 
@@ -25,13 +26,14 @@ class PromptMessageRole(Enum):
         for mode in cls:
             if mode.value == value:
                 return mode
-        raise ValueError(f'invalid prompt message type value {value}')
+        raise ValueError(f"invalid prompt message type value {value}")
 
 
 class PromptMessageTool(BaseModel):
     """
     Model class for prompt message tool.
     """
+
     name: str
     description: str
     parameters: dict
@@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel):
     """
     Model class for prompt message function.
     """
-    type: str = 'function'
+
+    type: str = "function"
     function: PromptMessageTool
 
 
@@ -49,14 +52,16 @@ class PromptMessageContentType(Enum):
     """
     Enum class for prompt message content type.
     """
-    TEXT = 'text'
-    IMAGE = 'image'
+
+    TEXT = "text"
+    IMAGE = "image"
 
 
 class PromptMessageContent(BaseModel):
     """
     Model class for prompt message content.
     """
+
     type: PromptMessageContentType
     data: str
 
@@ -65,6 +70,7 @@ class TextPromptMessageContent(PromptMessageContent):
     """
     Model class for text prompt message content.
     """
+
     type: PromptMessageContentType = PromptMessageContentType.TEXT
 
 
@@ -72,9 +78,10 @@ class ImagePromptMessageContent(PromptMessageContent):
     """
     Model class for image prompt message content.
     """
+
     class DETAIL(Enum):
-        LOW = 'low'
-        HIGH = 'high'
+        LOW = "low"
+        HIGH = "high"
 
     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
@@ -84,6 +91,7 @@ class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """
+
     role: PromptMessageRole
     content: Optional[str | list[PromptMessageContent]] = None
     name: Optional[str] = None
@@ -93,6 +101,7 @@ class UserPromptMessage(PromptMessage):
     """
     Model class for user prompt message.
     """
+
     role: PromptMessageRole = PromptMessageRole.USER
 
 
@@ -100,14 +109,17 @@ class AssistantPromptMessage(PromptMessage):
     """
     Model class for assistant prompt message.
     """
+
     class ToolCall(BaseModel):
         """
         Model class for assistant prompt message tool call.
         """
+
         class ToolCallFunction(BaseModel):
             """
             Model class for assistant prompt message tool call function.
             """
+
             name: str
             arguments: str
 
@@ -123,6 +135,7 @@ class SystemPromptMessage(PromptMessage):
     """
     Model class for system prompt message.
     """
+
     role: PromptMessageRole = PromptMessageRole.SYSTEM
 
 
@@ -130,5 +143,6 @@ class ToolPromptMessage(PromptMessage):
     """
     Model class for tool prompt message.
     """
+
     role: PromptMessageRole = PromptMessageRole.TOOL
     tool_call_id: str

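The prompt-message hunks above are likewise formatting-only, but they show how multimodal content is modelled. Below is a minimal sketch, assuming the annotations shown (content may be a plain string or a list of PromptMessageContent parts); the values are placeholders, not taken from the commit.

```python
# Sketch under the assumption that PromptMessage.content accepts either a plain
# string or a list of PromptMessageContent parts, as the annotations above show.
from model_providers.core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)

messages = [
    SystemPromptMessage(content="You describe images."),
    UserPromptMessage(
        content=[
            TextPromptMessageContent(data="What is in this picture?"),
            ImagePromptMessageContent(
                data="https://example.com/cat.png",  # placeholder URL
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
        ]
    ),
]
for message in messages:
    print(message.role.value)  # system, user
```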
@ -11,6 +11,7 @@ class ModelType(Enum):
|
|||||||
"""
|
"""
|
||||||
Enum class for model type.
|
Enum class for model type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LLM = "llm"
|
LLM = "llm"
|
||||||
TEXT_EMBEDDING = "text-embedding"
|
TEXT_EMBEDDING = "text-embedding"
|
||||||
RERANK = "rerank"
|
RERANK = "rerank"
|
||||||
@ -26,22 +27,28 @@ class ModelType(Enum):
|
|||||||
|
|
||||||
:return: model type
|
:return: model type
|
||||||
"""
|
"""
|
||||||
if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
|
if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
|
||||||
return cls.LLM
|
return cls.LLM
|
||||||
elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
|
elif (
|
||||||
|
origin_model_type == "embeddings"
|
||||||
|
or origin_model_type == cls.TEXT_EMBEDDING.value
|
||||||
|
):
|
||||||
return cls.TEXT_EMBEDDING
|
return cls.TEXT_EMBEDDING
|
||||||
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
|
elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
|
||||||
return cls.RERANK
|
return cls.RERANK
|
||||||
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
|
elif (
|
||||||
|
origin_model_type == "speech2text"
|
||||||
|
or origin_model_type == cls.SPEECH2TEXT.value
|
||||||
|
):
|
||||||
return cls.SPEECH2TEXT
|
return cls.SPEECH2TEXT
|
||||||
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
|
elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
|
||||||
return cls.TTS
|
return cls.TTS
|
||||||
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
|
elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
|
||||||
return cls.TEXT2IMG
|
return cls.TEXT2IMG
|
||||||
elif origin_model_type == cls.MODERATION.value:
|
elif origin_model_type == cls.MODERATION.value:
|
||||||
return cls.MODERATION
|
return cls.MODERATION
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'invalid origin model type {origin_model_type}')
|
raise ValueError(f"invalid origin model type {origin_model_type}")
|
||||||
|
|
||||||
def to_origin_model_type(self) -> str:
|
def to_origin_model_type(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -50,26 +57,28 @@ class ModelType(Enum):
|
|||||||
:return: origin model type
|
:return: origin model type
|
||||||
"""
|
"""
|
||||||
if self == self.LLM:
|
if self == self.LLM:
|
||||||
return 'text-generation'
|
return "text-generation"
|
||||||
elif self == self.TEXT_EMBEDDING:
|
elif self == self.TEXT_EMBEDDING:
|
||||||
return 'embeddings'
|
return "embeddings"
|
||||||
elif self == self.RERANK:
|
elif self == self.RERANK:
|
||||||
return 'reranking'
|
return "reranking"
|
||||||
elif self == self.SPEECH2TEXT:
|
elif self == self.SPEECH2TEXT:
|
||||||
return 'speech2text'
|
return "speech2text"
|
||||||
elif self == self.TTS:
|
elif self == self.TTS:
|
||||||
return 'tts'
|
return "tts"
|
||||||
elif self == self.MODERATION:
|
elif self == self.MODERATION:
|
||||||
return 'moderation'
|
return "moderation"
|
||||||
elif self == self.TEXT2IMG:
|
elif self == self.TEXT2IMG:
|
||||||
return 'text2img'
|
return "text2img"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'invalid model type {self}')
|
raise ValueError(f"invalid model type {self}")
|
||||||
|
|
||||||
|
|
||||||
class FetchFrom(Enum):
|
class FetchFrom(Enum):
|
||||||
"""
|
"""
|
||||||
Enum class for fetch from.
|
Enum class for fetch from.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PREDEFINED_MODEL = "predefined-model"
|
PREDEFINED_MODEL = "predefined-model"
|
||||||
CUSTOMIZABLE_MODEL = "customizable-model"
|
CUSTOMIZABLE_MODEL = "customizable-model"
|
||||||
|
|
||||||
@ -78,6 +87,7 @@ class ModelFeature(Enum):
|
|||||||
"""
|
"""
|
||||||
Enum class for llm feature.
|
Enum class for llm feature.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TOOL_CALL = "tool-call"
|
TOOL_CALL = "tool-call"
|
||||||
MULTI_TOOL_CALL = "multi-tool-call"
|
MULTI_TOOL_CALL = "multi-tool-call"
|
||||||
AGENT_THOUGHT = "agent-thought"
|
AGENT_THOUGHT = "agent-thought"
|
||||||
@ -89,6 +99,7 @@ class DefaultParameterName(Enum):
|
|||||||
"""
|
"""
|
||||||
Enum class for parameter template variable.
|
Enum class for parameter template variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TEMPERATURE = "temperature"
|
TEMPERATURE = "temperature"
|
||||||
TOP_P = "top_p"
|
TOP_P = "top_p"
|
||||||
PRESENCE_PENALTY = "presence_penalty"
|
PRESENCE_PENALTY = "presence_penalty"
|
||||||
@ -97,7 +108,7 @@ class DefaultParameterName(Enum):
|
|||||||
RESPONSE_FORMAT = "response_format"
|
RESPONSE_FORMAT = "response_format"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: Any) -> 'DefaultParameterName':
|
def value_of(cls, value: Any) -> "DefaultParameterName":
|
||||||
"""
|
"""
|
||||||
Get parameter name from value.
|
Get parameter name from value.
|
||||||
|
|
||||||
@ -107,13 +118,14 @@ class DefaultParameterName(Enum):
|
|||||||
for name in cls:
|
for name in cls:
|
||||||
if name.value == value:
|
if name.value == value:
|
||||||
return name
|
return name
|
||||||
raise ValueError(f'invalid parameter name {value}')
|
raise ValueError(f"invalid parameter name {value}")
|
||||||
|
|
||||||
|
|
||||||
class ParameterType(Enum):
|
class ParameterType(Enum):
|
||||||
"""
|
"""
|
||||||
Enum class for parameter type.
|
Enum class for parameter type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FLOAT = "float"
|
FLOAT = "float"
|
||||||
INT = "int"
|
INT = "int"
|
||||||
STRING = "string"
|
STRING = "string"
|
||||||
@ -124,6 +136,7 @@ class ModelPropertyKey(Enum):
|
|||||||
"""
|
"""
|
||||||
Enum class for model property key.
|
Enum class for model property key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MODE = "mode"
|
MODE = "mode"
|
||||||
CONTEXT_SIZE = "context_size"
|
CONTEXT_SIZE = "context_size"
|
||||||
MAX_CHUNKS = "max_chunks"
|
MAX_CHUNKS = "max_chunks"
|
||||||
@ -141,6 +154,7 @@ class ProviderModel(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Model class for provider model.
|
Model class for provider model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
label: I18nObject
|
label: I18nObject
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
@ -157,6 +171,7 @@ class ParameterRule(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Model class for parameter rule.
|
Model class for parameter rule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
use_template: Optional[str] = None
|
use_template: Optional[str] = None
|
||||||
label: I18nObject
|
label: I18nObject
|
||||||
@ -174,6 +189,7 @@ class PriceConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Model class for pricing info.
|
Model class for pricing info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input: Decimal
|
input: Decimal
|
||||||
output: Optional[Decimal] = None
|
output: Optional[Decimal] = None
|
||||||
unit: Decimal
|
unit: Decimal
|
||||||
@ -184,6 +200,7 @@ class AIModelEntity(ProviderModel):
|
|||||||
"""
|
"""
|
||||||
Model class for AI model.
|
Model class for AI model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parameter_rules: list[ParameterRule] = []
|
parameter_rules: list[ParameterRule] = []
|
||||||
pricing: Optional[PriceConfig] = None
|
pricing: Optional[PriceConfig] = None
|
||||||
|
|
||||||
@ -196,6 +213,7 @@ class PriceType(Enum):
|
|||||||
"""
|
"""
|
||||||
Enum class for price type.
|
Enum class for price type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
INPUT = "input"
|
INPUT = "input"
|
||||||
OUTPUT = "output"
|
OUTPUT = "output"
|
||||||
|
|
||||||
@ -204,6 +222,7 @@ class PriceInfo(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Model class for price info.
|
Model class for price info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unit_price: Decimal
|
unit_price: Decimal
|
||||||
unit: Decimal
|
unit: Decimal
|
||||||
total_amount: Decimal
|
total_amount: Decimal
|
||||||
|
|||||||
@@ -4,13 +4,18 @@ from typing import Optional
 from pydantic import BaseModel

 from model_providers.core.model_runtime.entities.common_entities import I18nObject
-from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelType,
+    ProviderModel,
+)


 class ConfigurateMethod(Enum):
     """
     Enum class for configurate method of provider model.
     """
+
     PREDEFINED_MODEL = "predefined-model"
     CUSTOMIZABLE_MODEL = "customizable-model"

@@ -19,6 +24,7 @@ class FormType(Enum):
     """
     Enum class for form type.
     """
+
     TEXT_INPUT = "text-input"
     SECRET_INPUT = "secret-input"
     SELECT = "select"
@@ -30,6 +36,7 @@ class FormShowOnObject(BaseModel):
     """
     Model class for form show on.
     """
+
     variable: str
     value: str

@@ -38,6 +45,7 @@ class FormOption(BaseModel):
     """
     Model class for form option.
     """
+
     label: I18nObject
     value: str
     show_on: list[FormShowOnObject] = []
@@ -45,15 +53,14 @@ class FormOption(BaseModel):
     def __init__(self, **data):
         super().__init__(**data)
         if not self.label:
-            self.label = I18nObject(
-                en_US=self.value
-            )
+            self.label = I18nObject(en_US=self.value)


 class CredentialFormSchema(BaseModel):
     """
     Model class for credential form schema.
     """
+
     variable: str
     label: I18nObject
     type: FormType
@@ -69,6 +76,7 @@ class ProviderCredentialSchema(BaseModel):
     """
     Model class for provider credential schema.
     """
+
     credential_form_schemas: list[CredentialFormSchema]


@@ -81,6 +89,7 @@ class ModelCredentialSchema(BaseModel):
     """
     Model class for model credential schema.
     """
+
     model: FieldModelSchema
     credential_form_schemas: list[CredentialFormSchema]

@@ -89,6 +98,7 @@ class SimpleProviderEntity(BaseModel):
     """
     Simple model class for provider.
     """
+
     provider: str
     label: I18nObject
     icon_small: Optional[I18nObject] = None
@@ -101,6 +111,7 @@ class ProviderHelpEntity(BaseModel):
     """
     Model class for provider help.
     """
+
     title: I18nObject
     url: I18nObject

@@ -109,6 +120,7 @@ class ProviderEntity(BaseModel):
     """
     Model class for provider.
     """
+
     provider: str
     label: I18nObject
     description: Optional[I18nObject] = None
@@ -137,7 +149,7 @@ class ProviderEntity(BaseModel):
             icon_small=self.icon_small,
             icon_large=self.icon_large,
             supported_model_types=self.supported_model_types,
-            models=self.models
+            models=self.models,
         )


@@ -145,5 +157,6 @@ class ProviderConfig(BaseModel):
     """
     Model class for provider config.
     """
+
     provider: str
     credentials: dict
@@ -5,6 +5,7 @@ class RerankDocument(BaseModel):
     """
     Model class for rerank document.
     """
+
     index: int
     text: str
     score: float
@@ -14,5 +15,6 @@ class RerankResult(BaseModel):
     """
     Model class for rerank result.
     """
+
     model: str
     docs: list[RerankDocument]
@@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage):
     """
     Model class for embedding usage.
     """
+
     tokens: int
     total_tokens: int
     unit_price: Decimal
@@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel):
     """
     Model class for text embedding result.
     """
+
     model: str
     embeddings: list[list[float]]
     usage: EmbeddingUsage
-
@@ -3,6 +3,7 @@ from typing import Optional

 class InvokeError(Exception):
     """Base class for all LLM exceptions."""
+
     description: Optional[str] = None

     def __init__(self, description: Optional[str] = None) -> None:
@@ -14,24 +15,29 @@ class InvokeError(Exception):

 class InvokeConnectionError(InvokeError):
     """Raised when the Invoke returns connection error."""
+
     description = "Connection Error"


 class InvokeServerUnavailableError(InvokeError):
     """Raised when the Invoke returns server unavailable error."""
+
     description = "Server Unavailable Error"


 class InvokeRateLimitError(InvokeError):
     """Raised when the Invoke returns rate limit error."""
+
     description = "Rate Limit Error"


 class InvokeAuthorizationError(InvokeError):
     """Raised when the Invoke returns authorization error."""
+
     description = "Incorrect model credentials provided, please check and try again. "


 class InvokeBadRequestError(InvokeError):
     """Raised when the Invoke returns bad request."""
+
     description = "Bad Request Error"
@@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception):
     """
     Credentials validate failed error
     """
+
     pass
@@ -16,15 +16,24 @@ from model_providers.core.model_runtime.entities.model_entities import (
     PriceInfo,
     PriceType,
 )
-from model_providers.core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
-from model_providers.core.utils.position_helper import get_position_map, sort_by_position_map
+from model_providers.core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeError,
+)
+from model_providers.core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import (
+    GPT2Tokenizer,
+)
+from model_providers.core.utils.position_helper import (
+    get_position_map,
+    sort_by_position_map,
+)


 class AIModel(ABC):
     """
     Base class for all models.
     """
+
     model_type: ModelType
     model_schemas: list[AIModelEntity] = None
     started_at: float = 0
@@ -60,18 +69,24 @@ class AIModel(ABC):
         :param error: model invoke error
         :return: unified error
         """
-        provider_name = self.__class__.__module__.split('.')[-3]
+        provider_name = self.__class__.__module__.split(".")[-3]

         for invoke_error, model_errors in self._invoke_error_mapping.items():
             if isinstance(error, tuple(model_errors)):
                 if invoke_error == InvokeAuthorizationError:
-                    return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ")
+                    return invoke_error(
+                        description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. "
+                    )

-                return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
+                return invoke_error(
+                    description=f"[{provider_name}] {invoke_error.description}, {str(error)}"
+                )

         return InvokeError(description=f"[{provider_name}] Error: {str(error)}")

-    def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
+    def get_price(
+        self, model: str, credentials: dict, price_type: PriceType, tokens: int
+    ) -> PriceInfo:
         """
         Get price for given model and tokens

@@ -99,15 +114,17 @@ class AIModel(ABC):
         if unit_price is None:
             return PriceInfo(
-                unit_price=decimal.Decimal('0.0'),
-                unit=decimal.Decimal('0.0'),
-                total_amount=decimal.Decimal('0.0'),
+                unit_price=decimal.Decimal("0.0"),
+                unit=decimal.Decimal("0.0"),
+                total_amount=decimal.Decimal("0.0"),
                 currency="USD",
             )

         # calculate total amount
         total_amount = tokens * unit_price * price_config.unit
-        total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+        total_amount = total_amount.quantize(
+            decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP
+        )

         return PriceInfo(
             unit_price=unit_price,
@@ -128,24 +145,28 @@ class AIModel(ABC):
         model_schemas = []

         # get module name
-        model_type = self.__class__.__module__.split('.')[-1]
+        model_type = self.__class__.__module__.split(".")[-1]

         # get provider name
-        provider_name = self.__class__.__module__.split('.')[-3]
+        provider_name = self.__class__.__module__.split(".")[-3]

         # get the path of current classes
         current_path = os.path.abspath(__file__)
         # get parent path of the current path
-        provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type)
+        provider_model_type_path = os.path.join(
+            os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
+        )

         # get all yaml files path under provider_model_type_path that do not start with __
         model_schema_yaml_paths = [
             os.path.join(provider_model_type_path, model_schema_yaml)
             for model_schema_yaml in os.listdir(provider_model_type_path)
-            if not model_schema_yaml.startswith('__')
-            and not model_schema_yaml.startswith('_')
-            and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
-            and model_schema_yaml.endswith('.yaml')
+            if not model_schema_yaml.startswith("__")
+            and not model_schema_yaml.startswith("_")
+            and os.path.isfile(
+                os.path.join(provider_model_type_path, model_schema_yaml)
+            )
+            and model_schema_yaml.endswith(".yaml")
         ]

         # get _position.yaml file path
@@ -154,59 +175,73 @@ class AIModel(ABC):
         # traverse all model_schema_yaml_paths
         for model_schema_yaml_path in model_schema_yaml_paths:
             # read yaml data from yaml file
-            with open(model_schema_yaml_path, encoding='utf-8') as f:
+            with open(model_schema_yaml_path, encoding="utf-8") as f:
                 yaml_data = yaml.safe_load(f)

             new_parameter_rules = []
-            for parameter_rule in yaml_data.get('parameter_rules', []):
-                if 'use_template' in parameter_rule:
+            for parameter_rule in yaml_data.get("parameter_rules", []):
+                if "use_template" in parameter_rule:
                     try:
-                        default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template'])
-                        default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
+                        default_parameter_name = DefaultParameterName.value_of(
+                            parameter_rule["use_template"]
+                        )
+                        default_parameter_rule = (
+                            self._get_default_parameter_rule_variable_map(
+                                default_parameter_name
+                            )
+                        )
                         copy_default_parameter_rule = default_parameter_rule.copy()
                         copy_default_parameter_rule.update(parameter_rule)
                         parameter_rule = copy_default_parameter_rule
                     except ValueError:
                         pass

-                if 'label' not in parameter_rule:
-                    parameter_rule['label'] = {
-                        'zh_Hans': parameter_rule['name'],
-                        'en_US': parameter_rule['name']
+                if "label" not in parameter_rule:
+                    parameter_rule["label"] = {
+                        "zh_Hans": parameter_rule["name"],
+                        "en_US": parameter_rule["name"],
                     }

                 new_parameter_rules.append(parameter_rule)

-            yaml_data['parameter_rules'] = new_parameter_rules
+            yaml_data["parameter_rules"] = new_parameter_rules

-            if 'label' not in yaml_data:
-                yaml_data['label'] = {
-                    'zh_Hans': yaml_data['model'],
-                    'en_US': yaml_data['model']
+            if "label" not in yaml_data:
+                yaml_data["label"] = {
+                    "zh_Hans": yaml_data["model"],
+                    "en_US": yaml_data["model"],
                 }

-            yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value
+            yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value

             try:
                 # yaml_data to entity
                 model_schema = AIModelEntity(**yaml_data)
             except Exception as e:
-                model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml")
-                raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:'
-                                f' {str(e)}')
+                model_schema_yaml_file_name = os.path.basename(
+                    model_schema_yaml_path
+                ).rstrip(".yaml")
+                raise Exception(
+                    f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:"
+                    f" {str(e)}"
+                )

             # cache model schema
             model_schemas.append(model_schema)

         # resort model schemas by position
-        model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
+        model_schemas = sort_by_position_map(
+            position_map, model_schemas, lambda x: x.model
+        )

         # cache model schemas
         self.model_schemas = model_schemas

         return model_schemas

-    def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
+    def get_model_schema(
+        self, model: str, credentials: Optional[dict] = None
+    ) -> Optional[AIModelEntity]:
         """
         Get model schema by model name and credentials

@@ -222,13 +257,17 @@ class AIModel(ABC):
             return model_map[model]

         if credentials:
-            model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
+            model_schema = self.get_customizable_model_schema_from_credentials(
+                model, credentials
+            )
             if model_schema:
                 return model_schema

         return None

-    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema_from_credentials(
+        self, model: str, credentials: dict
+    ) -> Optional[AIModelEntity]:
         """
         Get customizable model schema from credentials

@@ -238,7 +277,9 @@ class AIModel(ABC):
         """
         return self._get_customizable_model_schema(model, credentials)

-    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def _get_customizable_model_schema(
+        self, model: str, credentials: dict
+    ) -> Optional[AIModelEntity]:
         """
         Get customizable model schema and fill in the template
         """
@@ -252,26 +293,51 @@ class AIModel(ABC):
         for parameter_rule in schema.parameter_rules:
             if parameter_rule.use_template:
                 try:
-                    default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
-                    default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
-                    if not parameter_rule.max and 'max' in default_parameter_rule:
-                        parameter_rule.max = default_parameter_rule['max']
-                    if not parameter_rule.min and 'min' in default_parameter_rule:
-                        parameter_rule.min = default_parameter_rule['min']
-                    if not parameter_rule.default and 'default' in default_parameter_rule:
-                        parameter_rule.default = default_parameter_rule['default']
-                    if not parameter_rule.precision and 'precision' in default_parameter_rule:
-                        parameter_rule.precision = default_parameter_rule['precision']
-                    if not parameter_rule.required and 'required' in default_parameter_rule:
-                        parameter_rule.required = default_parameter_rule['required']
-                    if not parameter_rule.help and 'help' in default_parameter_rule:
-                        parameter_rule.help = I18nObject(
-                            en_US=default_parameter_rule['help']['en_US'],
-                        )
-                    if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
-                        parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                    if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
-                        parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
+                    default_parameter_name = DefaultParameterName.value_of(
+                        parameter_rule.use_template
+                    )
+                    default_parameter_rule = (
+                        self._get_default_parameter_rule_variable_map(
+                            default_parameter_name
+                        )
+                    )
+                    if not parameter_rule.max and "max" in default_parameter_rule:
+                        parameter_rule.max = default_parameter_rule["max"]
+                    if not parameter_rule.min and "min" in default_parameter_rule:
+                        parameter_rule.min = default_parameter_rule["min"]
+                    if (
+                        not parameter_rule.default
+                        and "default" in default_parameter_rule
+                    ):
+                        parameter_rule.default = default_parameter_rule["default"]
+                    if (
+                        not parameter_rule.precision
+                        and "precision" in default_parameter_rule
+                    ):
+                        parameter_rule.precision = default_parameter_rule["precision"]
+                    if (
+                        not parameter_rule.required
+                        and "required" in default_parameter_rule
+                    ):
+                        parameter_rule.required = default_parameter_rule["required"]
+                    if not parameter_rule.help and "help" in default_parameter_rule:
+                        parameter_rule.help = I18nObject(
+                            en_US=default_parameter_rule["help"]["en_US"],
+                        )
+                    if not parameter_rule.help.en_US and (
+                        "help" in default_parameter_rule
+                        and "en_US" in default_parameter_rule["help"]
+                    ):
+                        parameter_rule.help.en_US = default_parameter_rule["help"][
+                            "en_US"
+                        ]
+                    if not parameter_rule.help.zh_Hans and (
+                        "help" in default_parameter_rule
+                        and "zh_Hans" in default_parameter_rule["help"]
+                    ):
+                        parameter_rule.help.zh_Hans = default_parameter_rule[
+                            "help"
+                        ].get("zh_Hans", default_parameter_rule["help"]["en_US"])
                 except ValueError:
                     pass

@@ -281,7 +347,9 @@ class AIModel(ABC):

         return schema

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+    def get_customizable_model_schema(
+        self, model: str, credentials: dict
+    ) -> Optional[AIModelEntity]:
         """
         Get customizable model schema

@@ -291,7 +359,9 @@ class AIModel(ABC):
         """
         return None

-    def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
+    def _get_default_parameter_rule_variable_map(
+        self, name: DefaultParameterName
+    ) -> dict:
         """
         Get default parameter rule for given name

@@ -301,7 +371,7 @@ class AIModel(ABC):
         default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)

         if not default_parameter_rule:
-            raise Exception(f'Invalid model parameter rule name {name}')
+            raise Exception(f"Invalid model parameter rule name {name}")

         return default_parameter_rule
@@ -7,8 +7,16 @@ from collections.abc import Generator
 from typing import Optional, Union

 from model_providers.core.model_runtime.callbacks.base_callback import Callback
-from model_providers.core.model_runtime.callbacks.logging_callback import LoggingCallback
-from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from model_providers.core.model_runtime.callbacks.logging_callback import (
+    LoggingCallback,
+)
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMMode,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMUsage,
+)
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -32,13 +40,21 @@ class LargeLanguageModel(AIModel):
     """
     Model class for large language model.
     """
+
     model_type: ModelType = ModelType.LLM

-    def invoke(self, model: str, credentials: dict,
-               prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
-               tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-               stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-            -> Union[LLMResult, Generator]:
+    def invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: Optional[dict] = None,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model

@@ -57,7 +73,9 @@ class LargeLanguageModel(AIModel):
         if model_parameters is None:
             model_parameters = {}

-        model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
+        model_parameters = self._validate_and_filter_model_parameters(
+            model, model_parameters, credentials
+        )

         self.started_at = time.perf_counter()

@@ -76,7 +94,7 @@ class LargeLanguageModel(AIModel):
             stop=stop,
             stream=stream,
             user=user,
-            callbacks=callbacks
+            callbacks=callbacks,
         )

         try:
@@ -90,10 +108,19 @@ class LargeLanguageModel(AIModel):
                     stop=stop,
                     stream=stream,
                     user=user,
-                    callbacks=callbacks
+                    callbacks=callbacks,
                 )
             else:
-                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+                result = self._invoke(
+                    model,
+                    credentials,
+                    prompt_messages,
+                    model_parameters,
+                    tools,
+                    stop,
+                    stream,
+                    user,
+                )
         except Exception as e:
             self._trigger_invoke_error_callbacks(
                 model=model,
@@ -105,7 +132,7 @@ class LargeLanguageModel(AIModel):
                 stop=stop,
                 stream=stream,
                 user=user,
-                callbacks=callbacks
+                callbacks=callbacks,
             )

             raise self._transform_invoke_error(e)
@@ -121,7 +148,7 @@ class LargeLanguageModel(AIModel):
                 stop=stop,
                 stream=stream,
                 user=user,
-                callbacks=callbacks
+                callbacks=callbacks,
             )
         else:
             self._trigger_after_invoke_callbacks(
@@ -134,15 +161,23 @@ class LargeLanguageModel(AIModel):
                 stop=stop,
                 stream=stream,
                 user=user,
-                callbacks=callbacks
+                callbacks=callbacks,
             )

         return result

-    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
-                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
-                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+    def _code_block_mode_wrapper(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Code block mode wrapper, ensure the response is a code block with output markdown quote

@@ -177,36 +212,44 @@ if you are not sure about the structure.
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user
+            user=user,
         )

         model_parameters.pop("response_format")
         stop = stop or []
         stop.extend(["\n```", "```\n"])
         block_prompts = block_prompts.replace("{{block}}", code_block)

         # check if there is a system message
-        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+        if len(prompt_messages) > 0 and isinstance(
+            prompt_messages[0], SystemPromptMessage
+        ):
             # override the system message
             prompt_messages[0] = SystemPromptMessage(
-                content=block_prompts
-                .replace("{{instructions}}", prompt_messages[0].content)
+                content=block_prompts.replace(
+                    "{{instructions}}", prompt_messages[0].content
+                )
             )
         else:
             # insert the system message
-            prompt_messages.insert(0, SystemPromptMessage(
-                content=block_prompts
-                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
-            ))
+            prompt_messages.insert(
+                0,
+                SystemPromptMessage(
+                    content=block_prompts.replace(
+                        "{{instructions}}",
+                        f"Please output a valid {code_block} object.",
+                    )
+                ),
+            )

-        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+        if len(prompt_messages) > 0 and isinstance(
+            prompt_messages[-1], UserPromptMessage
+        ):
             # add ```JSON\n to the last message
             prompt_messages[-1].content += f"\n```{code_block}\n"
         else:
             # append a user message
-            prompt_messages.append(UserPromptMessage(
-                content=f"```{code_block}\n"
-            ))
+            prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))

         response = self._invoke(
             model=model,
@@ -216,33 +259,40 @@ if you are not sure about the structure.
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user
+            user=user,
         )

         if isinstance(response, Generator):
             first_chunk = next(response)

             def new_generator():
                 yield first_chunk
                 yield from response

-            if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
+            if (
+                first_chunk.delta.message.content
+                and first_chunk.delta.message.content.startswith("`")
+            ):
                 return self._code_block_mode_stream_processor_with_backtick(
                     model=model,
                     prompt_messages=prompt_messages,
-                    input_generator=new_generator()
+                    input_generator=new_generator(),
                 )
             else:
                 return self._code_block_mode_stream_processor(
                     model=model,
                     prompt_messages=prompt_messages,
-                    input_generator=new_generator()
+                    input_generator=new_generator(),
                 )

         return response

-    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
-                                          input_generator: Generator[LLMResultChunk, None, None]
-                                          ) -> Generator[LLMResultChunk, None, None]:
+    def _code_block_mode_stream_processor(
+        self,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        input_generator: Generator[LLMResultChunk, None, None],
+    ) -> Generator[LLMResultChunk, None, None]:
         """
         Code block mode stream processor, ensure the response is a code block with output markdown quote

@@ -291,15 +341,17 @@ if you are not sure about the structure.
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(
-                        content=new_piece,
-                        tool_calls=[]
+                        content=new_piece, tool_calls=[]
                     ),
-                )
+                ),
             )

-    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
-                                                        input_generator: Generator[LLMResultChunk, None, None]) \
-        -> Generator[LLMResultChunk, None, None]:
+    def _code_block_mode_stream_processor_with_backtick(
+        self,
+        model: str,
+        prompt_messages: list,
+        input_generator: Generator[LLMResultChunk, None, None],
+    ) -> Generator[LLMResultChunk, None, None]:
         """
         Code block mode stream processor, ensure the response is a code block with output markdown quote.
         This version skips the language identifier that follows the opening triple backticks.
@@ -366,26 +418,31 @@ if you are not sure about the structure.
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(
-                        content=new_piece,
-                        tool_calls=[]
+                        content=new_piece, tool_calls=[]
                     ),
-                )
+                ),
            )

-    def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
-                                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                                 tools: Optional[list[PromptMessageTool]] = None,
-                                 stop: Optional[list[str]] = None, stream: bool = True,
-                                 user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
+    def _invoke_result_generator(
+        self,
+        model: str,
+        result: Generator,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> Generator:
         """
         Invoke result generator

         :param result: result generator
         :return: result generator
         """
-        prompt_message = AssistantPromptMessage(
-            content=""
-        )
+        prompt_message = AssistantPromptMessage(content="")
         usage = None
         system_fingerprint = None
         real_model = model
@@ -404,7 +461,7 @@ if you are not sure about the structure.
                     stop=stop,
                     stream=stream,
                     user=user,
-                    callbacks=callbacks
+                    callbacks=callbacks,
                 )

                 prompt_message.content += chunk.delta.message.content
@@ -424,7 +481,7 @@ if you are not sure about the structure.
                     prompt_messages=prompt_messages,
                     message=prompt_message,
                     usage=usage if usage else LLMUsage.empty_usage(),
-                    system_fingerprint=system_fingerprint
+                    system_fingerprint=system_fingerprint,
                 ),
                 credentials=credentials,
                 prompt_messages=prompt_messages,
@@ -433,15 +490,21 @@ if you are not sure about the structure.
                 stop=stop,
                 stream=stream,
                 user=user,
-                callbacks=callbacks
+                callbacks=callbacks,
             )

     @abstractmethod
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model

@@ -456,10 +519,15 @@ if you are not sure about the structure.
         :return: full response or stream response chunk generator result
         """
         raise NotImplementedError

     @abstractmethod
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
         """
         Get number of tokens for given prompt messages

@@ -489,7 +557,9 @@ if you are not sure about the structure.
             for word in result.message.content:
                 assistant_prompt_message = AssistantPromptMessage(
                     content=word,
-                    tool_calls=tool_calls if index == (len(result.message.content) - 1) else []
+                    tool_calls=tool_calls
+                    if index == (len(result.message.content) - 1)
+                    else [],
                 )

                 yield LLMResultChunk(
@@ -499,7 +569,7 @@ if you are not sure about the structure.
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=assistant_prompt_message,
-                    )
+                    ),
                 )

                 index += 1
@@ -531,11 +601,15 @@ if you are not sure about the structure.

         mode = LLMMode.CHAT
         if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
-            mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
+            mode = LLMMode.value_of(
+                model_schema.model_properties[ModelPropertyKey.MODE]
+            )

         return mode

-    def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+    def _calc_response_usage(
+        self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
+    ) -> LLMUsage:
         """
         Calculate response usage

@@ -558,7 +632,7 @@ if you are not sure about the structure.
             model=model,
             credentials=credentials,
             price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
+            tokens=completion_tokens,
         )

         # transform usage
@@ -572,18 +646,26 @@ if you are not sure about the structure.
             completion_price_unit=completion_price_info.unit,
             completion_price=completion_price_info.total_amount,
             total_tokens=prompt_tokens + completion_tokens,
-            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
+            total_price=prompt_price_info.total_amount
+            + completion_price_info.total_amount,
             currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )

         return usage

-    def _trigger_before_invoke_callbacks(self, model: str, credentials: dict,
-                                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                                         tools: Optional[list[PromptMessageTool]] = None,
-                                         stop: Optional[list[str]] = None, stream: bool = True,
-                                         user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+    def _trigger_before_invoke_callbacks(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> None:
         """
         Trigger before invoke callbacks

@@ -609,19 +691,29 @@ if you are not sure about the structure.
                         tools=tools,
                         stop=stop,
                         stream=stream,
-                        user=user
+                        user=user,
                     )
                 except Exception as e:
                     if callback.raise_error:
                         raise e
                     else:
-                        logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
+                        logger.warning(
+                            f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}"
+                        )

-    def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict,
-                                     prompt_messages: list[PromptMessage], model_parameters: dict,
-                                     tools: Optional[list[PromptMessageTool]] = None,
-                                     stop: Optional[list[str]] = None, stream: bool = True,
-                                     user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+    def _trigger_new_chunk_callbacks(
+        self,
+        chunk: LLMResultChunk,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> None:
         """
         Trigger new chunk callbacks

@@ -648,19 +740,29 @@ if you are not sure about the structure.
                         tools=tools,
                         stop=stop,
                         stream=stream,
-                        user=user
+                        user=user,
                     )
                 except Exception as e:
                     if callback.raise_error:
                         raise e
                     else:
-                        logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
+                        logger.warning(
+                            f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}"
+                        )

-    def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict,
-                                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                                        tools: Optional[list[PromptMessageTool]] = None,
-                                        stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+    def _trigger_after_invoke_callbacks(
+        self,
+        model: str,
+        result: LLMResult,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> None:
         """
         Trigger after invoke callbacks

@@ -688,19 +790,29 @@ if you are not sure about the structure.
                         tools=tools,
                         stop=stop,
                         stream=stream,
-                        user=user
+                        user=user,
                     )
                 except Exception as e:
                     if callback.raise_error:
                         raise e
                     else:
-                        logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
+                        logger.warning(
+                            f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}"
+                        )

-    def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict,
-                                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                                        tools: Optional[list[PromptMessageTool]] = None,
-                                        stop: Optional[list[str]] = None, stream: bool = True,
-                                        user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+    def _trigger_invoke_error_callbacks(
+        self,
+        model: str,
+        ex: Exception,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+        callbacks: list[Callback] = None,
+    ) -> None:
         """
         Trigger invoke error callbacks

@@ -728,15 +840,19 @@ if you are not sure about the structure.
                         tools=tools,
                         stop=stop,
                         stream=stream,
-                        user=user
+                        user=user,
                    )
                 except Exception as e:
                     if callback.raise_error:
                         raise e
                     else:
-                        logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
+                        logger.warning(
+                            f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}"
+                        )

-    def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
+    def _validate_and_filter_model_parameters(
+        self, model: str, model_parameters: dict, credentials: dict
+    ) -> dict:
         """
         Validate model parameters

@@ -753,16 +869,23 @@ if you are not sure about the structure.
             parameter_name = parameter_rule.name
             parameter_value = model_parameters.get(parameter_name)
             if parameter_value is None:
-                if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
+                if (
+                    parameter_rule.use_template
+                    and parameter_rule.use_template in model_parameters
+                ):
                     # if parameter value is None, use template value variable name instead
                     parameter_value = model_parameters[parameter_rule.use_template]
                 else:
                     if parameter_rule.required:
                         if parameter_rule.default is not None:
-                            filtered_model_parameters[parameter_name] = parameter_rule.default
+                            filtered_model_parameters[
+                                parameter_name
+                            ] = parameter_rule.default
                             continue
                         else:
-                            raise ValueError(f"Model Parameter {parameter_name} is required.")
+                            raise ValueError(
+                                f"Model Parameter {parameter_name} is required."
+                            )
                     else:
                         continue

@@ -772,47 +895,81 @@ if you are not sure about the structure.
                     raise ValueError(f"Model Parameter {parameter_name} should be int.")

                 # validate parameter value range
-                if parameter_rule.min is not None and parameter_value < parameter_rule.min:
+                if (
+                    parameter_rule.min is not None
+                    and parameter_value < parameter_rule.min
+                ):
                     raise ValueError(
-                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
+                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
+                    )

-                if parameter_rule.max is not None and parameter_value > parameter_rule.max:
+                if (
+                    parameter_rule.max is not None
+                    and parameter_value > parameter_rule.max
+                ):
                     raise ValueError(
-                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
+                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
+                    )
             elif parameter_rule.type == ParameterType.FLOAT:
                 if not isinstance(parameter_value, float | int):
-                    raise ValueError(f"Model Parameter {parameter_name} should be float.")
+                    raise ValueError(
+                        f"Model Parameter {parameter_name} should be float."
+                    )

                 # validate parameter value precision
                 if parameter_rule.precision is not None:
                     if parameter_rule.precision == 0:
                         if parameter_value != int(parameter_value):
-                            raise ValueError(f"Model Parameter {parameter_name} should be int.")
-                    else:
-                        if parameter_value != round(parameter_value, parameter_rule.precision):
                             raise ValueError(
-                                f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.")
+                                f"Model Parameter {parameter_name} should be int."
+                            )
+                    else:
+                        if parameter_value != round(
+                            parameter_value, parameter_rule.precision
+                        ):
+                            raise ValueError(
+                                f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places."
+                            )

                 # validate parameter value range
-                if parameter_rule.min is not None and parameter_value < parameter_rule.min:
+                if (
+                    parameter_rule.min is not None
+                    and parameter_value < parameter_rule.min
+                ):
                     raise ValueError(
-                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.")
+                        f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
+                    )

-                if parameter_rule.max is not None and parameter_value > parameter_rule.max:
+                if (
+                    parameter_rule.max is not None
+                    and parameter_value > parameter_rule.max
+                ):
                     raise ValueError(
-                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.")
+                        f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
+                    )
             elif parameter_rule.type == ParameterType.BOOLEAN:
                 if not isinstance(parameter_value, bool):
-                    raise ValueError(f"Model Parameter {parameter_name} should be bool.")
+                    raise ValueError(
+                        f"Model Parameter {parameter_name} should be bool."
+                    )
             elif parameter_rule.type == ParameterType.STRING:
                 if not isinstance(parameter_value, str):
-                    raise ValueError(f"Model Parameter {parameter_name} should be string.")
+                    raise ValueError(
+                        f"Model Parameter {parameter_name} should be string."
+                    )

                 # validate options
-                if parameter_rule.options and parameter_value not in parameter_rule.options:
-                    raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
+                if (
+                    parameter_rule.options
+                    and parameter_value not in parameter_rule.options
+                ):
+                    raise ValueError(
+                        f"Model Parameter {parameter_name} should be one of {parameter_rule.options}."
+                    )
             else:
-                raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
+                raise ValueError(
+                    f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported."
+                )

             filtered_model_parameters[parameter_name] = parameter_value
@@ -4,7 +4,10 @@ from abc import ABC, abstractmethod

 import yaml

-from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelType,
+)
 from model_providers.core.model_runtime.entities.provider_entities import ProviderEntity
 from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -36,24 +39,26 @@ class ModelProvider(ABC):
             return self.provider_schema

         # get dirname of the current path
-        provider_name = self.__class__.__module__.split('.')[-1]
+        provider_name = self.__class__.__module__.split(".")[-1]

         # get the path of the model_provider classes
         base_path = os.path.abspath(__file__)
-        current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
+        current_path = os.path.join(
+            os.path.dirname(os.path.dirname(base_path)), provider_name
+        )

         # read provider schema from yaml file
-        yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
+        yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
         yaml_data = {}
         if os.path.exists(yaml_path):
-            with open(yaml_path, encoding='utf-8') as f:
+            with open(yaml_path, encoding="utf-8") as f:
                 yaml_data = yaml.safe_load(f)

         try:
             # yaml_data to entity
             provider_schema = ProviderEntity(**yaml_data)
         except Exception as e:
-            raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}')
+            raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")

         # cache schema
         self.provider_schema = provider_schema
@@ -88,37 +93,52 @@ class ModelProvider(ABC):
         :return:
         """
         # get dirname of the current path
-        provider_name = self.__class__.__module__.split('.')[-1]
+        provider_name = self.__class__.__module__.split(".")[-1]

         if f"{provider_name}.{model_type.value}" in self.model_instance_map:
             return self.model_instance_map[f"{provider_name}.{model_type.value}"]

         # get the path of the model type classes
         base_path = os.path.abspath(__file__)
-        model_type_name = model_type.value.replace('-', '_')
-        model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
-        model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py')
+        model_type_name = model_type.value.replace("-", "_")
+        model_type_path = os.path.join(
+            os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name
+        )
+        model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")

         if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
-            raise Exception(f'Invalid model type {model_type} for provider {provider_name}')
+            raise Exception(
+                f"Invalid model type {model_type} for provider {provider_name}"
+            )

         # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
-        parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
-        spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
+        parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
+        spec = importlib.util.spec_from_file_location(
+            f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path
|
||||||
|
)
|
||||||
mod = importlib.util.module_from_spec(spec)
|
mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod)
|
||||||
|
|
||||||
model_class = None
|
model_class = None
|
||||||
for name, obj in vars(mod).items():
|
for name, obj in vars(mod).items():
|
||||||
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
|
if (
|
||||||
and obj != AIModel and obj.__module__ == mod.__name__):
|
isinstance(obj, type)
|
||||||
|
and issubclass(obj, AIModel)
|
||||||
|
and not obj.__abstractmethods__
|
||||||
|
and obj != AIModel
|
||||||
|
and obj.__module__ == mod.__name__
|
||||||
|
):
|
||||||
model_class = obj
|
model_class = obj
|
||||||
break
|
break
|
||||||
|
|
||||||
if not model_class:
|
if not model_class:
|
||||||
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
|
raise Exception(
|
||||||
|
f"Missing AIModel Class for model type {model_type} in {model_type_py_path}"
|
||||||
|
)
|
||||||
|
|
||||||
model_instance_map = model_class()
|
model_instance_map = model_class()
|
||||||
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
|
self.model_instance_map[
|
||||||
|
f"{provider_name}.{model_type.value}"
|
||||||
|
] = model_instance_map
|
||||||
|
|
||||||
return model_instance_map
|
return model_instance_map
|
||||||
|
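The model-instance lookup above locates {model_type_name}.py inside the provider package, imports it with importlib, and picks the first concrete AIModel subclass defined in that module. A rough standalone sketch of that dynamic-loading step (the base class, path and module names below are placeholders, not the project's layout):

import importlib.util
from pathlib import Path


class PluginBase:
    """Stand-in for the AIModel base class the real code searches for."""


def load_plugin_class(py_path: str, module_name: str, base: type = PluginBase) -> type:
    """Import `py_path` as `module_name` and return its first concrete subclass of `base`."""
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {py_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    for obj in vars(mod).values():
        if (
            isinstance(obj, type)
            and issubclass(obj, base)
            and obj is not base
            and not getattr(obj, "__abstractmethods__", None)
            and obj.__module__ == mod.__name__
        ):
            return obj
    raise ImportError(f"No concrete {base.__name__} subclass found in {py_path}")


if __name__ == "__main__":
    # hypothetical usage: ./llm/llm.py would have to define a PluginBase subclass
    plugin_path = Path("llm") / "llm.py"
    if plugin_path.exists():
        print(load_plugin_class(str(plugin_path), "provider.llm.llm"))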
@@ -10,11 +10,12 @@ class ModerationModel(AIModel):
    """
    Model class for moderation model.
    """

    model_type: ModelType = ModelType.MODERATION

    def invoke(
        self, model: str, credentials: dict, text: str, user: Optional[str] = None
    ) -> bool:
        """
        Invoke moderation model

@@ -32,9 +33,9 @@ class ModerationModel(AIModel):
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self, model: str, credentials: dict, text: str, user: Optional[str] = None
    ) -> bool:
        """
        Invoke large language model

@@ -45,4 +46,3 @@ class ModerationModel(AIModel):
        :return: false if text is safe, true otherwise
        """
        raise NotImplementedError
@@ -11,12 +11,19 @@ class RerankModel(AIModel):
    """
    Base Model class for rerank model.
    """

    model_type: ModelType = ModelType.RERANK

    def invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

@@ -32,15 +39,23 @@ class RerankModel(AIModel):
        self.started_at = time.perf_counter()

        try:
            return self._invoke(
                model, credentials, query, docs, score_threshold, top_n, user
            )
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model
@@ -10,11 +10,12 @@ class Speech2TextModel(AIModel):
    """
    Model class for speech2text model.
    """

    model_type: ModelType = ModelType.SPEECH2TEXT

    def invoke(
        self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
    ) -> str:
        """
        Invoke large language model

@@ -30,9 +31,9 @@ class Speech2TextModel(AIModel):
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
    ) -> str:
        """
        Invoke large language model

@@ -54,4 +55,4 @@ class Speech2TextModel(AIModel):
        current_dir = os.path.dirname(os.path.abspath(__file__))

        # Construct the path to the audio file
        return os.path.join(current_dir, "audio.mp3")
@@ -9,11 +9,17 @@ class Text2ImageModel(AIModel):
    """
    Model class for text2img model.
    """

    model_type: ModelType = ModelType.TEXT2IMG

    def invoke(
        self,
        model: str,
        credentials: dict,
        prompt: str,
        model_parameters: dict,
        user: Optional[str] = None,
    ) -> list[IO[bytes]]:
        """
        Invoke Text2Image model

@@ -31,9 +37,14 @@ class Text2ImageModel(AIModel):
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt: str,
        model_parameters: dict,
        user: Optional[str] = None,
    ) -> list[IO[bytes]]:
        """
        Invoke Text2Image model
@@ -2,8 +2,13 @@ import time
from abc import abstractmethod
from typing import Optional

from model_providers.core.model_runtime.entities.model_entities import (
    ModelPropertyKey,
    ModelType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -11,11 +16,16 @@ class TextEmbeddingModel(AIModel):
    """
    Model class for text embedding model.
    """

    model_type: ModelType = ModelType.TEXT_EMBEDDING

    def invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke large language model

@@ -33,9 +43,13 @@ class TextEmbeddingModel(AIModel):
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke large language model

@@ -69,7 +83,10 @@ class TextEmbeddingModel(AIModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if (
            model_schema
            and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties
        ):
            return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]

        return 1000

@@ -84,7 +101,10 @@ class TextEmbeddingModel(AIModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if (
            model_schema
            and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
        ):
            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]

        return 1
@@ -7,27 +7,30 @@ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer
_tokenizer = None
_lock = Lock()


class GPT2Tokenizer:
    @staticmethod
    def _get_num_tokens_by_gpt2(text: str) -> int:
        """
        use gpt2 tokenizer to get num tokens
        """
        _tokenizer = GPT2Tokenizer.get_encoder()
        tokens = _tokenizer.encode(text, verbose=False)
        return len(tokens)

    @staticmethod
    def get_num_tokens(text: str) -> int:
        return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

    @staticmethod
    def get_encoder() -> Any:
        global _tokenizer, _lock
        with _lock:
            if _tokenizer is None:
                base_path = abspath(__file__)
                gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
                _tokenizer = TransformerGPT2Tokenizer.from_pretrained(
                    gpt2_tokenizer_path
                )

            return _tokenizer
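The tokenizer wrapper above lazily loads a GPT-2 tokenizer from a "gpt2" directory bundled next to the module and counts tokens with encode(). A minimal standalone sketch of the same token-counting idea, using the public "gpt2" checkpoint from Hugging Face instead of the project's bundled files (that checkpoint name is an assumption for illustration):

from threading import Lock

from transformers import GPT2Tokenizer

_tokenizer = None
_lock = Lock()


def get_num_tokens(text: str) -> int:
    """Count GPT-2 tokens, loading the tokenizer once under a lock."""
    global _tokenizer
    with _lock:
        if _tokenizer is None:
            # assumption: download the public checkpoint instead of a local "gpt2" dir
            _tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return len(_tokenizer.encode(text))


if __name__ == "__main__":
    print(get_num_tokens("Hello, world!"))  # prints a small token count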
@@ -4,7 +4,10 @@ import uuid
from abc import abstractmethod
from typing import Optional

from model_providers.core.model_runtime.entities.model_entities import (
    ModelPropertyKey,
    ModelType,
)
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -13,10 +16,19 @@ class TTSModel(AIModel):
    """
    Model class for ttstext model.
    """

    model_type: ModelType = ModelType.TTS

    def invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        streaming: bool,
        user: Optional[str] = None,
    ):
        """
        Invoke large language model

@@ -31,14 +43,29 @@ class TTSModel(AIModel):
        """
        try:
            self._is_ffmpeg_installed()
            return self._invoke(
                model=model,
                credentials=credentials,
                user=user,
                streaming=streaming,
                content_text=content_text,
                voice=voice,
                tenant_id=tenant_id,
            )
        except Exception as e:
            raise self._transform_invoke_error(e)

The abstract _invoke() gets the same one-parameter-per-line signature as invoke() above.

@@ -53,7 +80,9 @@ class TTSModel(AIModel):
        raise NotImplementedError

    def get_tts_model_voices(
        self, model: str, credentials: dict, language: Optional[str] = None
    ) -> list:

@@ -67,9 +96,13 @@ class TTSModel(AIModel):
        if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
            voices = model_schema.model_properties[ModelPropertyKey.VOICES]
            if language:
                return [
                    {"name": d["name"], "value": d["mode"]}
                    for d in voices
                    if language and language in d.get("language")
                ]
            else:
                return [{"name": d["name"], "value": d["mode"]} for d in voices]

@@ -81,7 +114,10 @@ / @@ -94,7 +130,10 @@ / @@ -104,7 +143,10 @@ class TTSModel(AIModel):
The _get_model_default_voice, _get_model_audio_type and _get_model_word_limit lookups each wrap their condition into a parenthesized block, in the same way _get_model_workers_limit does below.

@@ -114,13 +156,16 @@ class TTSModel(AIModel):
        if (
            model_schema
            and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties
        ):
            return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

    @staticmethod
    def _split_text_into_sentences(text: str, limit: int, delimiters=None):
        if delimiters is None:
            delimiters = set("。!?;\n")

        buf = []
        word_count = 0

@@ -128,7 +173,7 @@ class TTSModel(AIModel):
            buf.append(char)
            if char in delimiters:
                if word_count >= limit:
                    yield "".join(buf)
                    buf = []
                    word_count = 0
            else:

@@ -137,7 +182,7 @@ class TTSModel(AIModel):
                word_count += 1

        if buf:
            yield "".join(buf)

    @staticmethod
    def _is_ffmpeg_installed():

@@ -146,13 +191,17 @@ class TTSModel(AIModel):
            if "ffmpeg version" in output.decode("utf-8"):
                return True
            else:
                raise InvokeBadRequestError(
                    "ffmpeg is not installed, "
                    "details: https://docs.dify.ai/getting-started/install-self-hosted"
                    "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech"
                )
        except Exception:
            raise InvokeBadRequestError(
                "ffmpeg is not installed, "
                "details: https://docs.dify.ai/getting-started/install-self-hosted"
                "/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech"
            )

    # Todo: To improve the streaming function
    @staticmethod

@@ -160,6 +209,6 @@ class TTSModel(AIModel):
        hash_object = hashlib.sha256(file_content.encode())
        hex_digest = hash_object.hexdigest()

        namespace_uuid = uuid.UUID("a5da6ef9-b303-596f-8e88-bf8fa40f4b31")
        unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
        return str(unique_uuid)
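The _split_text_into_sentences helper above buffers characters and yields a chunk once a delimiter is reached after at least `limit` non-delimiter characters. A small self-contained sketch of that logic (the function name and sample text are illustrative, not the project's API):

def split_text_into_sentences(text: str, limit: int, delimiters=None):
    """Yield chunks of `text`, cutting at a delimiter once `limit` characters have accumulated."""
    if delimiters is None:
        delimiters = set("。!?;\n")

    buf = []
    word_count = 0
    for char in text:
        buf.append(char)
        if char in delimiters:
            # only cut when enough non-delimiter characters were collected
            if word_count >= limit:
                yield "".join(buf)
                buf = []
                word_count = 0
        else:
            word_count += 1

    if buf:
        # flush whatever is left after the last delimiter
        yield "".join(buf)


if __name__ == "__main__":
    sample = "第一句。第二句比较长一些。短句。"
    print(list(split_text_into_sentences(sample, limit=4)))
    # chunks end on "。" once at least four characters were buffered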
@@ -1,3 +1,5 @@
from model_providers.core.model_runtime.model_providers.model_provider_factory import (
    ModelProviderFactory,
)

model_provider_factory = ModelProviderFactory()
@@ -1,8 +1,12 @@
import logging

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)

@@ -21,11 +25,12 @@ class AnthropicProvider(ModelProvider):

            # Use `claude-instant-1` model for validate,
            model_instance.validate_credentials(
                model="claude-instant-1.2", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(
                f"{self.get_provider_schema().provider} credentials validate failed"
            )
            raise ex
@@ -18,7 +18,11 @@ from anthropic.types import (
The llm_entities import is split into a parenthesized block (LLMResult, LLMResultChunk, LLMResultChunkDelta); the base_callback and message_entities imports are unchanged.

@@ -37,8 +41,12 @@ from model_providers.core.model_runtime.errors.invoke import (
The CredentialsValidateFailedError and LargeLanguageModel imports are likewise split into parenthesized blocks.

@@ -51,11 +59,17 @@ if you are not sure about the structure.
class AnthropicLargeLanguageModel(LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:

@@ -70,11 +84,20 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The self._chat_generate(...) call is wrapped, and _chat_generate() gets the same one-parameter-per-line signature as _invoke().

@@ -91,23 +114,27 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
String keys switch to double quotes and long calls are wrapped:

        if "max_tokens_to_sample" in model_parameters:
            model_parameters["max_tokens"] = model_parameters.pop(
                "max_tokens_to_sample"
            )
        ...
        if stop:
            extra_model_kwargs["stop_sequences"] = stop

        if user:
            extra_model_kwargs["metadata"] = completion_create_params.Metadata(
                user_id=user
            )
        ...
        if system:
            extra_model_kwargs["system"] = system

@@ -115,22 +142,37 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The client.messages.create(...) call gains a trailing comma after **extra_model_kwargs, the stream and non-stream handler calls are wrapped, _code_block_mode_wrapper() gets the expanded signature (ending with callbacks: list[Callback] = None), and its response_format check is wrapped into a parenthesized condition.

@@ -142,17 +184,33 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The _transform_chat_json_prompts(...) call gains a trailing comma after response_format, the final return self._invoke(...) is expanded to one argument per line, and _transform_chat_json_prompts() gets the expanded signature with response_format: str = "JSON".

@@ -162,25 +220,40 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The ANTHROPIC_BLOCK_MODE_PROMPT .replace(...).replace(...) chains and the SystemPromptMessage / AssistantPromptMessage construction are reflowed across lines, and get_num_tokens() gets the expanded signature; behaviour is unchanged.

@@ -214,13 +287,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
stream=False gains a trailing comma and _handle_chat_generate_response() gets the expanded signature.

@@ -243,24 +321,32 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The get_num_tokens and _calc_response_usage calls are wrapped, the LLMResult(...) constructor gains a trailing comma after usage=usage, and _handle_chat_generate_stream_response() gets the expanded signature.

@@ -269,7 +355,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        full_assistant_content = ""

@@ -284,28 +370,26 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
The streamed usage calculation is wrapped, and the AssistantPromptMessage(content="") and AssistantPromptMessage(content=chunk_text) constructions are collapsed onto single lines.

@@ -315,7 +399,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(...) gains a trailing comma.

@@ -326,18 +410,22 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        credentials_kwargs = {
            "api_key": credentials["anthropic_api_key"],
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }

        if "anthropic_api_url" in credentials and credentials["anthropic_api_url"]:
            credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip(
                "/"
            )
            credentials_kwargs["base_url"] = credentials["anthropic_api_url"]

        return credentials_kwargs

@@ -348,7 +436,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
_convert_prompt_messages() gets a wrapped signature and the append of _convert_prompt_message_to_dict(message) is wrapped.

@@ -364,38 +454,57 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
In _convert_prompt_message_to_dict the cast(...), requests.get(...), mimetypes.guess_type(...) and base64.b64encode(...).decode("utf-8") calls are wrapped, the text and image sub-message dicts switch to double quotes with trailing commas, and the supported mime-type check is reflowed to one item per line ("image/jpeg", "image/png", "image/gif", "image/webp") with a wrapped ValueError.

@@ -450,7 +559,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
_convert_messages_to_prompt_anthropic() gets a wrapped signature.

@@ -458,15 +569,14 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        if not messages:
            return ""

        messages = messages.copy()  # don't mutate the original list
        if not isinstance(messages[-1], AssistantPromptMessage):
            messages.append(AssistantPromptMessage(content=""))

        text = "".join(
            self._convert_one_message_to_text(message) for message in messages
        )

@@ -485,22 +595,18 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        return {
            InvokeConnectionError: [
                anthropic.APIConnectionError,
                anthropic.APITimeoutError,
            ],
            InvokeServerUnavailableError: [anthropic.InternalServerError],
            InvokeRateLimitError: [anthropic.RateLimitError],
            InvokeAuthorizationError: [
                anthropic.AuthenticationError,
                anthropic.PermissionDeniedError,
            ],
            InvokeBadRequestError: [
                anthropic.BadRequestError,
                anthropic.NotFoundError,
                anthropic.UnprocessableEntityError,
                anthropic.APIError,
            ],
        }
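The _convert_prompt_message_to_dict hunk earlier in this file turns image prompt content into Anthropic's base64 source payload: plain URLs are fetched and encoded, data URIs are split on ";base64,". A rough standalone sketch of that conversion (the helper name and the simplified error handling are illustrative, not the project's API):

import base64
import mimetypes

import requests


def to_anthropic_image_source(data: str) -> dict:
    """Build the {"type": "base64", ...} source dict Anthropic expects for an image."""
    if not data.startswith("data:"):
        # plain URL: fetch the bytes and guess the mime type from the URL
        image_content = requests.get(data).content
        mime_type, _ = mimetypes.guess_type(data)
        base64_data = base64.b64encode(image_content).decode("utf-8")
    else:
        # data URI: everything after ";base64," is already encoded
        header, base64_data = data.split(";base64,")
        mime_type = header.replace("data:", "")

    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
        raise ValueError(f"Unsupported image type {mime_type}")

    return {
        "type": "image",
        "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
    }


if __name__ == "__main__":
    # 1x1 PNG as a data URI, so the example runs without a network call
    tiny_png = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )
    print(to_anthropic_image_source(tiny_png)["source"]["media_type"])  # image/png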
@@ -9,16 +9,18 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
    AZURE_OPENAI_API_VERSION,
)


class _CommonAzureOpenAI:
    @staticmethod
    def _to_credential_kwargs(credentials: dict) -> dict:
        api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION)
        credentials_kwargs = {
            "api_key": credentials["openai_api_key"],
            "azure_endpoint": credentials["openai_api_base"],
            "api_version": api_version,
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,

@@ -29,24 +31,17 @@ class _CommonAzureOpenAI:
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [
                openai.AuthenticationError,
                openai.PermissionDeniedError,
            ],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }
|||||||
@@ -14,11 +14,12 @@ from model_providers.core.model_runtime.entities.model_entities import (
     PriceConfig,
 )
 
-AZURE_OPENAI_API_VERSION = '2024-02-15-preview'
+AZURE_OPENAI_API_VERSION = "2024-02-15-preview"
 
+
 def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
     rule = ParameterRule(
-        name='max_tokens',
+        name="max_tokens",
         **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS],
     )
     rule.default = default
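`_get_max_tokens` starts from the shared `PARAMETER_RULE_TEMPLATE` entry for `max_tokens` and then overrides the numeric bounds per base model. The same pattern in isolation, using a plain stand-in class instead of the repo's `ParameterRule` (illustrative only, not the project's code):

    from dataclasses import dataclass

    @dataclass
    class FakeRule:
        name: str
        default: int = 0
        min: int = 0
        max: int = 0

    def make_max_tokens_rule(default: int, min_val: int, max_val: int) -> FakeRule:
        # Build the rule, then override the bounds for the specific base model,
        # e.g. default=512, min_val=1, max_val=4096 for gpt-35-turbo below.
        rule = FakeRule(name="max_tokens")
        rule.default = default
        rule.min = min_val
        rule.max = max_val
        return rule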
@@ -34,11 +35,11 @@ class AzureBaseModel(BaseModel):
 
 LLM_BASE_MODELS = [
     AzureBaseModel(
-        base_model_name='gpt-35-turbo',
+        base_model_name="gpt-35-turbo",
         entity=AIModelEntity(
-            model='fake-deployment-name',
+            model="fake-deployment-name",
             label=I18nObject(
-                en_US='fake-deployment-name-label',
+                en_US="fake-deployment-name-label",
             ),
             model_type=ModelType.LLM,
             features=[
@ -53,37 +54,37 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=4096)
|
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.001,
|
input=0.001,
|
||||||
output=0.002,
|
output=0.002,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-35-turbo-16k',
|
base_model_name="gpt-35-turbo-16k",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[
|
features=[
|
||||||
@ -98,37 +99,37 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=16385)
|
_get_max_tokens(default=512, min_val=1, max_val=16385),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.003,
|
input=0.003,
|
||||||
output=0.004,
|
output=0.004,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-4',
|
base_model_name="gpt-4",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[
|
features=[
|
||||||
@ -143,32 +144,29 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=8192),
|
_get_max_tokens(default=512, min_val=1, max_val=8192),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='seed',
|
name="seed",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
zh_Hans='种子',
|
type="int",
|
||||||
en_US='Seed'
|
|
||||||
),
|
|
||||||
type='int',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
||||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
@ -176,34 +174,31 @@ LLM_BASE_MODELS = [
|
|||||||
max=1,
|
max=1,
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='response_format',
|
name="response_format",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
zh_Hans='回复格式',
|
type="string",
|
||||||
en_US='response_format'
|
|
||||||
),
|
|
||||||
type='string',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='指定模型必须输出的格式',
|
zh_Hans="指定模型必须输出的格式",
|
||||||
en_US='specifying the format that the model must output'
|
en_US="specifying the format that the model must output",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
options=['text', 'json_object']
|
options=["text", "json_object"],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.03,
|
input=0.03,
|
||||||
output=0.06,
|
output=0.06,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-4-32k',
|
base_model_name="gpt-4-32k",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[
|
features=[
|
||||||
@ -218,32 +213,29 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=32768),
|
_get_max_tokens(default=512, min_val=1, max_val=32768),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='seed',
|
name="seed",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
zh_Hans='种子',
|
type="int",
|
||||||
en_US='Seed'
|
|
||||||
),
|
|
||||||
type='int',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
||||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
@ -251,34 +243,31 @@ LLM_BASE_MODELS = [
|
|||||||
max=1,
|
max=1,
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='response_format',
|
name="response_format",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
zh_Hans='回复格式',
|
type="string",
|
||||||
en_US='response_format'
|
|
||||||
),
|
|
||||||
type='string',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='指定模型必须输出的格式',
|
zh_Hans="指定模型必须输出的格式",
|
||||||
en_US='specifying the format that the model must output'
|
en_US="specifying the format that the model must output",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
options=['text', 'json_object']
|
options=["text", "json_object"],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.06,
|
input=0.06,
|
||||||
output=0.12,
|
output=0.12,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-4-1106-preview',
|
base_model_name="gpt-4-1106-preview",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[
|
features=[
|
||||||
@ -293,32 +282,29 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='seed',
|
name="seed",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
zh_Hans='种子',
|
type="int",
|
||||||
en_US='Seed'
|
|
||||||
),
|
|
||||||
type='int',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
||||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
@ -326,39 +312,34 @@ LLM_BASE_MODELS = [
|
|||||||
max=1,
|
max=1,
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='response_format',
|
name="response_format",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
zh_Hans='回复格式',
|
type="string",
|
||||||
en_US='response_format'
|
|
||||||
),
|
|
||||||
type='string',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='指定模型必须输出的格式',
|
zh_Hans="指定模型必须输出的格式",
|
||||||
en_US='specifying the format that the model must output'
|
en_US="specifying the format that the model must output",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
options=['text', 'json_object']
|
options=["text", "json_object"],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.01,
|
input=0.01,
|
||||||
output=0.03,
|
output=0.03,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-4-vision-preview',
|
base_model_name="gpt-4-vision-preview",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[
|
features=[ModelFeature.VISION],
|
||||||
ModelFeature.VISION
|
|
||||||
],
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_properties={
|
model_properties={
|
||||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||||
@ -366,32 +347,29 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='seed',
|
name="seed",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="种子", en_US="Seed"),
|
||||||
zh_Hans='种子',
|
type="int",
|
||||||
en_US='Seed'
|
|
||||||
),
|
|
||||||
type='int',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。",
|
||||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
en_US="If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
precision=2,
|
precision=2,
|
||||||
@ -399,34 +377,31 @@ LLM_BASE_MODELS = [
|
|||||||
max=1,
|
max=1,
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='response_format',
|
name="response_format",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||||
zh_Hans='回复格式',
|
type="string",
|
||||||
en_US='response_format'
|
|
||||||
),
|
|
||||||
type='string',
|
|
||||||
help=I18nObject(
|
help=I18nObject(
|
||||||
zh_Hans='指定模型必须输出的格式',
|
zh_Hans="指定模型必须输出的格式",
|
||||||
en_US='specifying the format that the model must output'
|
en_US="specifying the format that the model must output",
|
||||||
),
|
),
|
||||||
required=False,
|
required=False,
|
||||||
options=['text', 'json_object']
|
options=["text", "json_object"],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.01,
|
input=0.01,
|
||||||
output=0.03,
|
output=0.03,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='gpt-35-turbo-instruct',
|
base_model_name="gpt-35-turbo-instruct",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
@ -436,19 +411,19 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||||
@ -457,16 +432,16 @@ LLM_BASE_MODELS = [
|
|||||||
input=0.0015,
|
input=0.0015,
|
||||||
output=0.002,
|
output=0.002,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='text-davinci-003',
|
base_model_name="text-davinci-003",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(
|
||||||
en_US='fake-deployment-name-label',
|
en_US="fake-deployment-name-label",
|
||||||
),
|
),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
@ -476,19 +451,19 @@ LLM_BASE_MODELS = [
|
|||||||
},
|
},
|
||||||
parameter_rules=[
|
parameter_rules=[
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='presence_penalty',
|
name="presence_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='frequency_penalty',
|
name="frequency_penalty",
|
||||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||||
),
|
),
|
||||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||||
@ -497,20 +472,18 @@ LLM_BASE_MODELS = [
|
|||||||
input=0.02,
|
input=0.02,
|
||||||
output=0.02,
|
output=0.02,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
EMBEDDING_BASE_MODELS = [
|
EMBEDDING_BASE_MODELS = [
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='text-embedding-ada-002',
|
base_model_name="text-embedding-ada-002",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model_properties={
|
model_properties={
|
||||||
@ -520,17 +493,15 @@ EMBEDDING_BASE_MODELS = [
|
|||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.0001,
|
input=0.0001,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='text-embedding-3-small',
|
base_model_name="text-embedding-3-small",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model_properties={
|
model_properties={
|
||||||
@ -540,17 +511,15 @@ EMBEDDING_BASE_MODELS = [
|
|||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.00002,
|
input=0.00002,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='text-embedding-3-large',
|
base_model_name="text-embedding-3-large",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model_properties={
|
model_properties={
|
||||||
@ -560,135 +529,237 @@ EMBEDDING_BASE_MODELS = [
|
|||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.00013,
|
input=0.00013,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
SPEECH2TEXT_BASE_MODELS = [
|
SPEECH2TEXT_BASE_MODELS = [
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='whisper-1',
|
base_model_name="whisper-1",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.SPEECH2TEXT,
|
model_type=ModelType.SPEECH2TEXT,
|
||||||
model_properties={
|
model_properties={
|
||||||
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
|
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
|
||||||
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
|
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm",
|
||||||
}
|
},
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
TTS_BASE_MODELS = [
|
TTS_BASE_MODELS = [
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='tts-1',
|
base_model_name="tts-1",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TTS,
|
model_type=ModelType.TTS,
|
||||||
model_properties={
|
model_properties={
|
||||||
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
|
ModelPropertyKey.DEFAULT_VOICE: "alloy",
|
||||||
ModelPropertyKey.VOICES: [
|
ModelPropertyKey.VOICES: [
|
||||||
{
|
{
|
||||||
'mode': 'alloy',
|
"mode": "alloy",
|
||||||
'name': 'Alloy',
|
"name": "Alloy",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'echo',
|
"mode": "echo",
|
||||||
'name': 'Echo',
|
"name": "Echo",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'fable',
|
"mode": "fable",
|
||||||
'name': 'Fable',
|
"name": "Fable",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'onyx',
|
"mode": "onyx",
|
||||||
'name': 'Onyx',
|
"name": "Onyx",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'nova',
|
"mode": "nova",
|
||||||
'name': 'Nova',
|
"name": "Nova",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'shimmer',
|
"mode": "shimmer",
|
||||||
'name': 'Shimmer',
|
"name": "Shimmer",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ModelPropertyKey.WORD_LIMIT: 120,
|
ModelPropertyKey.WORD_LIMIT: 120,
|
||||||
ModelPropertyKey.AUDIO_TYPE: 'mp3',
|
ModelPropertyKey.AUDIO_TYPE: "mp3",
|
||||||
ModelPropertyKey.MAX_WORKERS: 5
|
ModelPropertyKey.MAX_WORKERS: 5,
|
||||||
},
|
},
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.015,
|
input=0.015,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
),
|
),
|
||||||
AzureBaseModel(
|
AzureBaseModel(
|
||||||
base_model_name='tts-1-hd',
|
base_model_name="tts-1-hd",
|
||||||
entity=AIModelEntity(
|
entity=AIModelEntity(
|
||||||
model='fake-deployment-name',
|
model="fake-deployment-name",
|
||||||
label=I18nObject(
|
label=I18nObject(en_US="fake-deployment-name-label"),
|
||||||
en_US='fake-deployment-name-label'
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.TTS,
|
model_type=ModelType.TTS,
|
||||||
model_properties={
|
model_properties={
|
||||||
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
|
ModelPropertyKey.DEFAULT_VOICE: "alloy",
|
||||||
ModelPropertyKey.VOICES: [
|
ModelPropertyKey.VOICES: [
|
||||||
{
|
{
|
||||||
'mode': 'alloy',
|
"mode": "alloy",
|
||||||
'name': 'Alloy',
|
"name": "Alloy",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'echo',
|
"mode": "echo",
|
||||||
'name': 'Echo',
|
"name": "Echo",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'fable',
|
"mode": "fable",
|
||||||
'name': 'Fable',
|
"name": "Fable",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'onyx',
|
"mode": "onyx",
|
||||||
'name': 'Onyx',
|
"name": "Onyx",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'nova',
|
"mode": "nova",
|
||||||
'name': 'Nova',
|
"name": "Nova",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'mode': 'shimmer',
|
"mode": "shimmer",
|
||||||
'name': 'Shimmer',
|
"name": "Shimmer",
|
||||||
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
|
"language": [
|
||||||
|
"zh-Hans",
|
||||||
|
"en-US",
|
||||||
|
"de-DE",
|
||||||
|
"fr-FR",
|
||||||
|
"es-ES",
|
||||||
|
"it-IT",
|
||||||
|
"th-TH",
|
||||||
|
"id-ID",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
ModelPropertyKey.WORD_LIMIT: 120,
|
ModelPropertyKey.WORD_LIMIT: 120,
|
||||||
ModelPropertyKey.AUDIO_TYPE: 'mp3',
|
ModelPropertyKey.AUDIO_TYPE: "mp3",
|
||||||
ModelPropertyKey.MAX_WORKERS: 5
|
ModelPropertyKey.MAX_WORKERS: 5,
|
||||||
},
|
},
|
||||||
pricing=PriceConfig(
|
pricing=PriceConfig(
|
||||||
input=0.03,
|
input=0.03,
|
||||||
unit=0.001,
|
unit=0.001,
|
||||||
currency='USD',
|
currency="USD",
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
 import logging
 
-from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from model_providers.core.model_runtime.model_providers.__base.model_provider import (
+    ModelProvider,
+)
 
 logger = logging.getLogger(__name__)
 
 
 class AzureOpenAIProvider(ModelProvider):
-
     def validate_provider_credentials(self, credentials: dict) -> None:
         pass
@@ -6,11 +6,23 @@ from typing import Optional, Union, cast
 import tiktoken
 from openai import AzureOpenAI, Stream
 from openai.types import Completion
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
-from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessageToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaFunctionCall,
+    ChoiceDeltaToolCall,
+)
 from openai.types.chat.chat_completion_message import FunctionCall
 
-from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMMode,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -22,26 +34,47 @@ from model_providers.core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
-from model_providers.core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel
+from model_providers.core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelPropertyKey,
+)
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
+    LargeLanguageModel,
+)
+from model_providers.core.model_runtime.model_providers.azure_openai._common import (
+    _CommonAzureOpenAI,
+)
+from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
+    LLM_BASE_MODELS,
+    AzureBaseModel,
+)
 
 logger = logging.getLogger(__name__)
 
 
 class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
-
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
-
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
-
-        if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
+        ai_model_entity = self._get_ai_model_entity(
+            credentials.get("base_model_name"), model
+        )
+
+        if (
+            ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
+            == LLMMode.CHAT.value
+        ):
             # chat model
             return self._chat_generate(
                 model=model,
@@ -51,7 +84,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user
+                user=user,
             )
         else:
             # text completion model
@@ -62,14 +95,19 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 model_parameters=model_parameters,
                 stop=stop,
                 stream=stream,
-                user=user
+                user=user,
             )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
-
-        model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
-            ModelPropertyKey.MODE)
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
+        model_mode = self._get_ai_model_entity(
+            credentials.get("base_model_name"), model
+        ).entity.model_properties.get(ModelPropertyKey.MODE)
 
         if model_mode == LLMMode.CHAT.value:
             # chat model
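`get_num_tokens` dispatches to message-based or string-based counting depending on the model mode; the module imports `tiktoken` for the encoding work. A standalone sketch of string-based counting, assuming the `cl100k_base` encoding (the project's code resolves the encoding from the deployment's base model):

    import tiktoken

    def count_tokens(text: str) -> int:
        # Rough token count for a plain string; chat counting additionally has
        # to account for per-message overhead tokens.
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))

    print(count_tokens("ping"))  # a small integer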
@@ -79,27 +117,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         return self._num_tokens_from_string(credentials, prompt_messages[0].content)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
-        if 'openai_api_base' not in credentials:
-            raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
+        if "openai_api_base" not in credentials:
+            raise CredentialsValidateFailedError(
+                "Azure OpenAI API Base Endpoint is required"
+            )
 
-        if 'openai_api_key' not in credentials:
-            raise CredentialsValidateFailedError('Azure OpenAI API key is required')
+        if "openai_api_key" not in credentials:
+            raise CredentialsValidateFailedError("Azure OpenAI API key is required")
 
-        if 'base_model_name' not in credentials:
-            raise CredentialsValidateFailedError('Base Model Name is required')
+        if "base_model_name" not in credentials:
+            raise CredentialsValidateFailedError("Base Model Name is required")
 
-        ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
+        ai_model_entity = self._get_ai_model_entity(
+            credentials.get("base_model_name"), model
+        )
 
         if not ai_model_entity:
-            raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
+            raise CredentialsValidateFailedError(
+                f'Base Model Name {credentials["base_model_name"]} is invalid'
+            )
 
         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))
 
-            if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
+            if (
+                ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
+                == LLMMode.CHAT.value
+            ):
                 # chat model
                 client.chat.completions.create(
-                    messages=[{"role": "user", "content": 'ping'}],
+                    messages=[{"role": "user", "content": "ping"}],
                     model=model,
                     temperature=0,
                     max_tokens=20,
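`validate_credentials` checks the three required keys and then issues a cheap one-token "ping" request to confirm the deployment answers. A hedged usage sketch; the deployment name and credential values are placeholders, and constructing the model class directly is an assumption (in the project it is normally obtained through the model manager):

    model = AzureOpenAILargeLanguageModel()
    try:
        model.validate_credentials(
            model="my-gpt-35-deployment",
            credentials={
                "openai_api_base": "https://my-resource.openai.azure.com",
                "openai_api_key": "<azure-openai-api-key>",
                "base_model_name": "gpt-35-turbo",
            },
        )
    except CredentialsValidateFailedError as ex:
        print(f"credential check failed: {ex}")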
@ -108,7 +155,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
else:
|
else:
|
||||||
# text completion model
|
# text completion model
|
||||||
client.completions.create(
|
client.completions.create(
|
||||||
prompt='ping',
|
prompt="ping",
|
||||||
model=model,
|
model=model,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
@ -117,23 +164,33 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
|
self, model: str, credentials: dict
|
||||||
|
) -> Optional[AIModelEntity]:
|
||||||
|
ai_model_entity = self._get_ai_model_entity(
|
||||||
|
credentials.get("base_model_name"), model
|
||||||
|
)
|
||||||
return ai_model_entity.entity if ai_model_entity else None
|
return ai_model_entity.entity if ai_model_entity else None
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict,
|
def _generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
self,
|
||||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
extra_model_kwargs = {}
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
extra_model_kwargs['stop'] = stop
|
extra_model_kwargs["stop"] = stop
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
extra_model_kwargs['user'] = user
|
extra_model_kwargs["user"] = user
|
||||||
|
|
||||||
# text completion model
|
# text completion model
|
||||||
response = client.completions.create(
|
response = client.completions.create(
|
||||||
@ -141,22 +198,29 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs
|
**extra_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
|
def _handle_generate_response(
|
||||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: Completion,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> LLMResult:
|
||||||
assistant_text = response.choices[0].text
|
assistant_text = response.choices[0].text
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||||
content=assistant_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
if response.usage:
|
if response.usage:
|
||||||
@ -165,11 +229,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
completion_tokens = response.usage.completion_tokens
|
completion_tokens = response.usage.completion_tokens
|
||||||
else:
|
else:
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
prompt_tokens = self._num_tokens_from_string(
|
||||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
credentials, prompt_messages[0].content
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_string(
|
||||||
|
credentials, assistant_text
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# transform response
|
# transform response
|
||||||
result = LLMResult(
|
result = LLMResult(
|
||||||
@ -182,23 +252,26 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
|
def _handle_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage]) -> Generator:
|
self,
|
||||||
full_text = ''
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: Stream[Completion],
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> Generator:
|
||||||
|
full_text = ""
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
if delta.finish_reason is None and (delta.text is None or delta.text == ''):
|
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
text = delta.text if delta.text else ''
|
text = delta.text if delta.text else ""
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||||
content=text
|
|
||||||
)
|
|
||||||
|
|
||||||
full_text += text
|
full_text += text
|
||||||
|
|
||||||
@ -210,11 +283,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
completion_tokens = chunk.usage.completion_tokens
|
completion_tokens = chunk.usage.completion_tokens
|
||||||
else:
|
else:
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_string(credentials, prompt_messages[0].content)
|
prompt_tokens = self._num_tokens_from_string(
|
||||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
credentials, prompt_messages[0].content
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_string(
|
||||||
|
credentials, full_text
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=chunk.model,
|
model=chunk.model,
|
||||||
@ -224,8 +303,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=delta.finish_reason,
|
finish_reason=delta.finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
@ -235,14 +314,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _chat_generate(self, model: str, credentials: dict,
|
def _chat_generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
model: str,
|
||||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
||||||
|
|
||||||
response_format = model_parameters.get("response_format")
|
response_format = model_parameters.get("response_format")
|
||||||
@ -258,17 +343,20 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
# extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
||||||
extra_model_kwargs['functions'] = [{
|
extra_model_kwargs["functions"] = [
|
||||||
"name": tool.name,
|
{
|
||||||
"description": tool.description,
|
"name": tool.name,
|
||||||
"parameters": tool.parameters
|
"description": tool.description,
|
||||||
} for tool in tools]
|
"parameters": tool.parameters,
|
||||||
|
}
|
||||||
|
for tool in tools
|
||||||
|
]
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
extra_model_kwargs['stop'] = stop
|
extra_model_kwargs["stop"] = stop
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
extra_model_kwargs['user'] = user
|
extra_model_kwargs["user"] = user
|
||||||
|
|
||||||
# chat model
|
# chat model
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
@ -280,27 +368,36 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
return self._handle_chat_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages, tools
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
return self._handle_chat_generate_response(
|
||||||
|
model, credentials, response, prompt_messages, tools
|
||||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
|
)
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> LLMResult:
|
|
||||||
|
|
||||||
|
def _handle_chat_generate_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: ChatCompletion,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> LLMResult:
|
||||||
assistant_message = response.choices[0].message
|
assistant_message = response.choices[0].message
|
||||||
# assistant_message_tool_calls = assistant_message.tool_calls
|
# assistant_message_tool_calls = assistant_message.tool_calls
|
||||||
assistant_message_function_call = assistant_message.function_call
|
assistant_message_function_call = assistant_message.function_call
|
||||||
|
|
||||||
# extract tool calls from response
|
# extract tool calls from response
|
||||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
function_call = self._extract_response_function_call(
|
||||||
|
assistant_message_function_call
|
||||||
|
)
|
||||||
tool_calls = [function_call] if function_call else []
|
tool_calls = [function_call] if function_call else []
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=assistant_message.content,
|
content=assistant_message.content, tool_calls=tool_calls
|
||||||
tool_calls=tool_calls
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
@ -310,11 +407,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
completion_tokens = response.usage.completion_tokens
|
completion_tokens = response.usage.completion_tokens
|
||||||
else:
|
else:
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
|
credentials, prompt_messages, tools
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
credentials, [assistant_prompt_message]
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# transform response
|
# transform response
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
@ -327,24 +430,31 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
def _handle_chat_generate_stream_response(
|
||||||
response: Stream[ChatCompletionChunk],
|
self,
|
||||||
prompt_messages: list[PromptMessage],
|
model: str,
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
|
credentials: dict,
|
||||||
|
response: Stream[ChatCompletionChunk],
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> Generator:
|
||||||
index = 0
|
index = 0
|
||||||
full_assistant_content = ''
|
full_assistant_content = ""
|
||||||
delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
|
delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None
|
||||||
real_model = model
|
real_model = model
|
||||||
system_fingerprint = None
|
system_fingerprint = None
|
||||||
completion = ''
|
completion = ""
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
|
if (
|
||||||
delta.delta.function_call is None:
|
delta.finish_reason is None
|
||||||
|
and (delta.delta.content is None or delta.delta.content == "")
|
||||||
|
and delta.delta.function_call is None
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# assistant_message_tool_calls = delta.delta.tool_calls
|
# assistant_message_tool_calls = delta.delta.tool_calls
|
||||||
@ -355,36 +465,44 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
# handle process of stream function call
|
# handle process of stream function call
|
||||||
if assistant_message_function_call:
|
if assistant_message_function_call:
|
||||||
# message has not ended yet
|
# message has not ended yet
|
||||||
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
|
delta_assistant_message_function_call_storage.arguments += (
|
||||||
|
assistant_message_function_call.arguments
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# message has ended
|
# message has ended
|
||||||
assistant_message_function_call = delta_assistant_message_function_call_storage
|
assistant_message_function_call = (
|
||||||
|
delta_assistant_message_function_call_storage
|
||||||
|
)
|
||||||
delta_assistant_message_function_call_storage = None
|
delta_assistant_message_function_call_storage = None
|
||||||
else:
|
else:
|
||||||
if assistant_message_function_call:
|
if assistant_message_function_call:
|
||||||
# start of stream function call
|
# start of stream function call
|
||||||
delta_assistant_message_function_call_storage = assistant_message_function_call
|
delta_assistant_message_function_call_storage = (
|
||||||
|
assistant_message_function_call
|
||||||
|
)
|
||||||
if delta_assistant_message_function_call_storage.arguments is None:
|
if delta_assistant_message_function_call_storage.arguments is None:
|
||||||
delta_assistant_message_function_call_storage.arguments = ''
|
delta_assistant_message_function_call_storage.arguments = ""
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# extract tool calls from response
|
# extract tool calls from response
|
||||||
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||||
function_call = self._extract_response_function_call(assistant_message_function_call)
|
function_call = self._extract_response_function_call(
|
||||||
|
assistant_message_function_call
|
||||||
|
)
|
||||||
tool_calls = [function_call] if function_call else []
|
tool_calls = [function_call] if function_call else []
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else '',
|
content=delta.delta.content if delta.delta.content else "",
|
||||||
tool_calls=tool_calls
|
tool_calls=tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content if delta.delta.content else ''
|
full_assistant_content += delta.delta.content if delta.delta.content else ""
|
||||||
|
|
||||||
real_model = chunk.model
|
real_model = chunk.model
|
||||||
system_fingerprint = chunk.system_fingerprint
|
system_fingerprint = chunk.system_fingerprint
|
||||||
completion += delta.delta.content if delta.delta.content else ''
|
completion += delta.delta.content if delta.delta.content else ""
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=real_model,
|
model=real_model,
|
||||||
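The hunk above reformats the logic that stitches a streamed function call back together: partial `arguments` strings arrive across several chunks and are appended to a storage object until the stream signals the call has ended. A minimal standalone sketch of that accumulation pattern, using a plain dataclass instead of openai's ChoiceDeltaFunctionCall (all names here are illustrative only, not taken from the diff):

import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class FunctionCallDelta:
    # hypothetical stand-in for openai's ChoiceDeltaFunctionCall
    name: Optional[str] = None
    arguments: Optional[str] = None


def accumulate_function_call(deltas: list) -> tuple:
    """Concatenate streamed argument fragments until the call is complete."""
    storage: Optional[FunctionCallDelta] = None
    for delta in deltas:
        if storage is None:
            # first fragment: remember the name and start collecting arguments
            storage = FunctionCallDelta(name=delta.name, arguments=delta.arguments or "")
        else:
            storage.arguments += delta.arguments or ""
    return storage.name, json.loads(storage.arguments)


chunks = [
    FunctionCallDelta(name="get_weather", arguments='{"city": '),
    FunctionCallDelta(arguments='"Beijing"}'),
]
print(accumulate_function_call(chunks))  # ('get_weather', {'city': 'Beijing'})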
@ -393,21 +511,25 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
|
credentials, prompt_messages, tools
|
||||||
full_assistant_prompt_message = AssistantPromptMessage(
|
)
|
||||||
content=completion
|
|
||||||
|
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
credentials, [full_assistant_prompt_message]
|
||||||
)
|
)
|
||||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=real_model,
|
model=real_model,
|
||||||
@ -415,55 +537,52 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
system_fingerprint=system_fingerprint,
|
system_fingerprint=system_fingerprint,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=AssistantPromptMessage(content=''),
|
message=AssistantPromptMessage(content=""),
|
||||||
finish_reason='stop',
|
finish_reason="stop",
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
|
def _extract_response_tool_calls(
|
||||||
-> list[AssistantPromptMessage.ToolCall]:
|
response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall],
|
||||||
|
) -> list[AssistantPromptMessage.ToolCall]:
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
if response_tool_calls:
|
if response_tool_calls:
|
||||||
for response_tool_call in response_tool_calls:
|
for response_tool_call in response_tool_calls:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_tool_call.function.name,
|
name=response_tool_call.function.name,
|
||||||
arguments=response_tool_call.function.arguments
|
arguments=response_tool_call.function.arguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=response_tool_call.id,
|
id=response_tool_call.id,
|
||||||
type=response_tool_call.type,
|
type=response_tool_call.type,
|
||||||
function=function
|
function=function,
|
||||||
)
|
)
|
||||||
tool_calls.append(tool_call)
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_response_function_call(response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \
|
def _extract_response_function_call(
|
||||||
-> AssistantPromptMessage.ToolCall:
|
response_function_call: FunctionCall | ChoiceDeltaFunctionCall,
|
||||||
|
) -> AssistantPromptMessage.ToolCall:
|
||||||
tool_call = None
|
tool_call = None
|
||||||
if response_function_call:
|
if response_function_call:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_function_call.name,
|
name=response_function_call.name,
|
||||||
arguments=response_function_call.arguments
|
arguments=response_function_call.arguments,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=response_function_call.name,
|
id=response_function_call.name, type="function", function=function
|
||||||
type="function",
|
|
||||||
function=function
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_call
|
return tool_call
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
|
def _convert_prompt_message_to_dict(message: PromptMessage) -> dict:
|
||||||
|
|
||||||
if isinstance(message, UserPromptMessage):
|
if isinstance(message, UserPromptMessage):
|
||||||
message = cast(UserPromptMessage, message)
|
message = cast(UserPromptMessage, message)
|
||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
@ -472,20 +591,24 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
sub_messages = []
|
sub_messages = []
|
||||||
for message_content in message.content:
|
for message_content in message.content:
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
message_content = cast(TextPromptMessageContent, message_content)
|
message_content = cast(
|
||||||
|
TextPromptMessageContent, message_content
|
||||||
|
)
|
||||||
sub_message_dict = {
|
sub_message_dict = {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": message_content.data
|
"text": message_content.data,
|
||||||
}
|
}
|
||||||
sub_messages.append(sub_message_dict)
|
sub_messages.append(sub_message_dict)
|
||||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||||
message_content = cast(ImagePromptMessageContent, message_content)
|
message_content = cast(
|
||||||
|
ImagePromptMessageContent, message_content
|
||||||
|
)
|
||||||
sub_message_dict = {
|
sub_message_dict = {
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": message_content.data,
|
"url": message_content.data,
|
||||||
"detail": message_content.detail.value
|
"detail": message_content.detail.value,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
sub_messages.append(sub_message_dict)
|
sub_messages.append(sub_message_dict)
|
||||||
|
|
||||||
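The `_convert_prompt_message_to_dict` hunk above assembles OpenAI-style chat message dicts, including the mixed text/image form used for vision inputs. A small hedged sketch of that structure (the URL and field values are placeholders, not taken from the diff):

def user_message_to_dict(text: str, image_url: str, detail: str = "low") -> dict:
    """Build an OpenAI-style chat message mixing text and image content,
    mirroring the sub_message_dict structures assembled above."""
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "image_url", "image_url": {"url": image_url, "detail": detail}},
        ],
    }


print(user_message_to_dict("describe this picture", "https://example.com/cat.png"))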
@ -514,7 +637,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
message_dict = {
|
message_dict = {
|
||||||
"role": "function",
|
"role": "function",
|
||||||
"content": message.content,
|
"content": message.content,
|
||||||
"name": message.tool_call_id
|
"name": message.tool_call_id,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
@ -524,10 +647,14 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
def _num_tokens_from_string(self, credentials: dict, text: str,
|
def _num_tokens_from_string(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
credentials: dict,
|
||||||
|
text: str,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(credentials['base_model_name'])
|
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
@ -538,13 +665,17 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
|
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage],
|
def _num_tokens_from_messages(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
credentials: dict,
|
||||||
|
messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||||
model = credentials['base_model_name']
|
model = credentials["base_model_name"]
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -578,10 +709,10 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
# which need to download the image and then get the resolution for calculation,
|
# which need to download the image and then get the resolution for calculation,
|
||||||
# and will increase the request delay
|
# and will increase the request delay
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
text = ''
|
text = ""
|
||||||
for item in value:
|
for item in value:
|
||||||
if isinstance(item, dict) and item['type'] == 'text':
|
if isinstance(item, dict) and item["type"] == "text":
|
||||||
text += item['text']
|
text += item["text"]
|
||||||
|
|
||||||
value = text
|
value = text
|
||||||
|
|
||||||
@ -611,41 +742,42 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
def _num_tokens_for_tools(
|
||||||
|
encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
|
||||||
|
) -> int:
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
num_tokens += len(encoding.encode('type'))
|
num_tokens += len(encoding.encode("type"))
|
||||||
num_tokens += len(encoding.encode('function'))
|
num_tokens += len(encoding.encode("function"))
|
||||||
|
|
||||||
# calculate num tokens for function object
|
# calculate num tokens for function object
|
||||||
num_tokens += len(encoding.encode('name'))
|
num_tokens += len(encoding.encode("name"))
|
||||||
num_tokens += len(encoding.encode(tool.name))
|
num_tokens += len(encoding.encode(tool.name))
|
||||||
num_tokens += len(encoding.encode('description'))
|
num_tokens += len(encoding.encode("description"))
|
||||||
num_tokens += len(encoding.encode(tool.description))
|
num_tokens += len(encoding.encode(tool.description))
|
||||||
parameters = tool.parameters
|
parameters = tool.parameters
|
||||||
num_tokens += len(encoding.encode('parameters'))
|
num_tokens += len(encoding.encode("parameters"))
|
||||||
if 'title' in parameters:
|
if "title" in parameters:
|
||||||
num_tokens += len(encoding.encode('title'))
|
num_tokens += len(encoding.encode("title"))
|
||||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||||
num_tokens += len(encoding.encode('type'))
|
num_tokens += len(encoding.encode("type"))
|
||||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||||
if 'properties' in parameters:
|
if "properties" in parameters:
|
||||||
num_tokens += len(encoding.encode('properties'))
|
num_tokens += len(encoding.encode("properties"))
|
||||||
for key, value in parameters.get('properties').items():
|
for key, value in parameters.get("properties").items():
|
||||||
num_tokens += len(encoding.encode(key))
|
num_tokens += len(encoding.encode(key))
|
||||||
for field_key, field_value in value.items():
|
for field_key, field_value in value.items():
|
||||||
num_tokens += len(encoding.encode(field_key))
|
num_tokens += len(encoding.encode(field_key))
|
||||||
if field_key == 'enum':
|
if field_key == "enum":
|
||||||
for enum_field in field_value:
|
for enum_field in field_value:
|
||||||
num_tokens += 3
|
num_tokens += 3
|
||||||
num_tokens += len(encoding.encode(enum_field))
|
num_tokens += len(encoding.encode(enum_field))
|
||||||
else:
|
else:
|
||||||
num_tokens += len(encoding.encode(field_key))
|
num_tokens += len(encoding.encode(field_key))
|
||||||
num_tokens += len(encoding.encode(str(field_value)))
|
num_tokens += len(encoding.encode(str(field_value)))
|
||||||
if 'required' in parameters:
|
if "required" in parameters:
|
||||||
num_tokens += len(encoding.encode('required'))
|
num_tokens += len(encoding.encode("required"))
|
||||||
for required_field in parameters['required']:
|
for required_field in parameters["required"]:
|
||||||
num_tokens += 3
|
num_tokens += 3
|
||||||
num_tokens += len(encoding.encode(required_field))
|
num_tokens += len(encoding.encode(required_field))
|
||||||
|
|
||||||
|
|||||||
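`_num_tokens_for_tools` above walks every field of each tool definition and adds the tiktoken length of each encoded string. A simplified approximation of the same idea, assuming tiktoken is installed; it skips some of the per-field bookkeeping of the real method and is only a rough estimate:

import tiktoken


def rough_tool_tokens(name: str, description: str, parameters: dict) -> int:
    """Rough token estimate for one function/tool definition."""
    enc = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for text in ("type", "function", "name", name, "description", description, "parameters"):
        num_tokens += len(enc.encode(text))
    for key, value in parameters.get("properties", {}).items():
        num_tokens += len(enc.encode(key))
        for field_key, field_value in value.items():
            num_tokens += len(enc.encode(field_key))
            num_tokens += len(enc.encode(str(field_value)))
    return num_tokens


print(rough_tool_tokens("get_weather", "Query current weather", {"properties": {"city": {"type": "string"}}}))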
@ -4,10 +4,19 @@ from typing import IO, Optional
|
|||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
|
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
CredentialsValidateFailedError,
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
|
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import (
|
||||||
|
Speech2TextModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
|
||||||
|
_CommonAzureOpenAI,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
|
||||||
|
SPEECH2TEXT_BASE_MODELS,
|
||||||
|
AzureBaseModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
||||||
@ -15,9 +24,9 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
|||||||
Model class for Azure OpenAI speech-to-text model.
|
Model class for Azure OpenAI speech-to-text model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
file: IO[bytes], user: Optional[str] = None) \
|
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
|
||||||
-> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Invoke speech2text model
|
Invoke speech2text model
|
||||||
|
|
||||||
@ -40,12 +49,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
|||||||
try:
|
try:
|
||||||
audio_file_path = self._get_demo_file_path()
|
audio_file_path = self._get_demo_file_path()
|
||||||
|
|
||||||
with open(audio_file_path, 'rb') as audio_file:
|
with open(audio_file_path, "rb") as audio_file:
|
||||||
self._speech2text_invoke(model, credentials, audio_file)
|
self._speech2text_invoke(model, credentials, audio_file)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
|
def _speech2text_invoke(
|
||||||
|
self, model: str, credentials: dict, file: IO[bytes]
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Invoke speech2text model
|
Invoke speech2text model
|
||||||
|
|
||||||
@ -64,11 +75,14 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
|
|||||||
|
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
self, model: str, credentials: dict
|
||||||
|
) -> Optional[AIModelEntity]:
|
||||||
|
ai_model_entity = self._get_ai_model_entity(
|
||||||
|
credentials["base_model_name"], model
|
||||||
|
)
|
||||||
return ai_model_entity.entity
|
return ai_model_entity.entity
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
|
||||||
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
|
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
|
||||||
|
|||||||
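The speech-to-text wrapper above ultimately calls the Azure OpenAI audio transcription endpoint through the openai SDK. A minimal direct sketch of that call; the endpoint, key, API version, deployment name and audio file below are all placeholders, none of them come from the diff:

from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://example-resource.openai.azure.com",  # placeholder
    api_key="<azure-openai-api-key>",  # placeholder
    api_version="2024-02-01",
)

with open("demo.mp3", "rb") as audio_file:
    # "whisper-1" stands in for the name of your whisper deployment
    transcription = client.audio.transcriptions.create(model="whisper-1", file=audio_file)

print(transcription.text)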
@ -7,28 +7,46 @@ import numpy as np
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, PriceType
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
AIModelEntity,
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
PriceType,
|
||||||
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
from model_providers.core.model_runtime.entities.text_embedding_entities import (
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel
|
EmbeddingUsage,
|
||||||
|
TextEmbeddingResult,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||||
|
TextEmbeddingModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
|
||||||
|
_CommonAzureOpenAI,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
|
||||||
|
EMBEDDING_BASE_MODELS,
|
||||||
|
AzureBaseModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||||
|
def _invoke(
|
||||||
def _invoke(self, model: str, credentials: dict,
|
self,
|
||||||
texts: list[str], user: Optional[str] = None) \
|
model: str,
|
||||||
-> TextEmbeddingResult:
|
credentials: dict,
|
||||||
base_model_name = credentials['base_model_name']
|
texts: list[str],
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
|
base_model_name = credentials["base_model_name"]
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
client = AzureOpenAI(**credentials_kwargs)
|
client = AzureOpenAI(**credentials_kwargs)
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
extra_model_kwargs = {}
|
||||||
if user:
|
if user:
|
||||||
extra_model_kwargs['user'] = user
|
extra_model_kwargs["user"] = user
|
||||||
|
|
||||||
extra_model_kwargs['encoding_format'] = 'base64'
|
extra_model_kwargs["encoding_format"] = "base64"
|
||||||
|
|
||||||
context_size = self._get_context_size(model, credentials)
|
context_size = self._get_context_size(model, credentials)
|
||||||
max_chunks = self._get_max_chunks(model, credentials)
|
max_chunks = self._get_max_chunks(model, credentials)
|
||||||
@ -44,11 +62,9 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
token = enc.encode(
|
token = enc.encode(text)
|
||||||
text
|
|
||||||
)
|
|
||||||
for j in range(0, len(token), context_size):
|
for j in range(0, len(token), context_size):
|
||||||
tokens += [token[j: j + context_size]]
|
tokens += [token[j : j + context_size]]
|
||||||
indices += [i]
|
indices += [i]
|
||||||
|
|
||||||
batched_embeddings = []
|
batched_embeddings = []
|
||||||
@ -58,8 +74,8 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||||
model=model,
|
model=model,
|
||||||
client=client,
|
client=client,
|
||||||
texts=tokens[i: i + max_chunks],
|
texts=tokens[i : i + max_chunks],
|
||||||
extra_model_kwargs=extra_model_kwargs
|
extra_model_kwargs=extra_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
used_tokens += embedding_used_tokens
|
||||||
@ -78,7 +94,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
model=model,
|
model=model,
|
||||||
client=client,
|
client=client,
|
||||||
texts="",
|
texts="",
|
||||||
extra_model_kwargs=extra_model_kwargs
|
extra_model_kwargs=extra_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
used_tokens += embedding_used_tokens
|
||||||
@ -89,15 +105,11 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
|
|
||||||
# calc usage
|
# calc usage
|
||||||
usage = self._calc_response_usage(
|
usage = self._calc_response_usage(
|
||||||
model=model,
|
model=model, credentials=credentials, tokens=used_tokens
|
||||||
credentials=credentials,
|
|
||||||
tokens=used_tokens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
return TextEmbeddingResult(
|
||||||
embeddings=embeddings,
|
embeddings=embeddings, usage=usage, model=base_model_name
|
||||||
usage=usage,
|
|
||||||
model=base_model_name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
@ -105,7 +117,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
enc = tiktoken.encoding_for_model(credentials['base_model_name'])
|
enc = tiktoken.encoding_for_model(credentials["base_model_name"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
@ -118,57 +130,78 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
return total_num_tokens
|
return total_num_tokens
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
if 'openai_api_base' not in credentials:
|
if "openai_api_base" not in credentials:
|
||||||
raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required')
|
raise CredentialsValidateFailedError(
|
||||||
|
"Azure OpenAI API Base Endpoint is required"
|
||||||
|
)
|
||||||
|
|
||||||
if 'openai_api_key' not in credentials:
|
if "openai_api_key" not in credentials:
|
||||||
raise CredentialsValidateFailedError('Azure OpenAI API key is required')
|
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
||||||
|
|
||||||
if 'base_model_name' not in credentials:
|
if "base_model_name" not in credentials:
|
||||||
raise CredentialsValidateFailedError('Base Model Name is required')
|
raise CredentialsValidateFailedError("Base Model Name is required")
|
||||||
|
|
||||||
if not self._get_ai_model_entity(credentials['base_model_name'], model):
|
if not self._get_ai_model_entity(credentials["base_model_name"], model):
|
||||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
raise CredentialsValidateFailedError(
|
||||||
|
f'Base Model Name {credentials["base_model_name"]} is invalid'
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
client = AzureOpenAI(**credentials_kwargs)
|
client = AzureOpenAI(**credentials_kwargs)
|
||||||
|
|
||||||
self._embedding_invoke(
|
self._embedding_invoke(
|
||||||
model=model,
|
model=model, client=client, texts=["ping"], extra_model_kwargs={}
|
||||||
client=client,
|
|
||||||
texts=['ping'],
|
|
||||||
extra_model_kwargs={}
|
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
self, model: str, credentials: dict
|
||||||
|
) -> Optional[AIModelEntity]:
|
||||||
|
ai_model_entity = self._get_ai_model_entity(
|
||||||
|
credentials["base_model_name"], model
|
||||||
|
)
|
||||||
return ai_model_entity.entity
|
return ai_model_entity.entity
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
|
def _embedding_invoke(
|
||||||
extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
|
model: str,
|
||||||
|
client: AzureOpenAI,
|
||||||
|
texts: Union[list[str], str],
|
||||||
|
extra_model_kwargs: dict,
|
||||||
|
) -> tuple[list[list[float]], int]:
|
||||||
response = client.embeddings.create(
|
response = client.embeddings.create(
|
||||||
input=texts,
|
input=texts,
|
||||||
model=model,
|
model=model,
|
||||||
**extra_model_kwargs,
|
**extra_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
|
if (
|
||||||
|
"encoding_format" in extra_model_kwargs
|
||||||
|
and extra_model_kwargs["encoding_format"] == "base64"
|
||||||
|
):
|
||||||
# decode base64 embedding
|
# decode base64 embedding
|
||||||
return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
|
return (
|
||||||
response.usage.total_tokens)
|
[
|
||||||
|
list(
|
||||||
|
np.frombuffer(base64.b64decode(data.embedding), dtype="float32")
|
||||||
|
)
|
||||||
|
for data in response.data
|
||||||
|
],
|
||||||
|
response.usage.total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
return [data.embedding for data in response.data], response.usage.total_tokens
|
return [data.embedding for data in response.data], response.usage.total_tokens
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
def _calc_response_usage(
|
||||||
|
self, model: str, credentials: dict, tokens: int
|
||||||
|
) -> EmbeddingUsage:
|
||||||
input_price_info = self.get_price(
|
input_price_info = self.get_price(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
price_type=PriceType.INPUT,
|
price_type=PriceType.INPUT,
|
||||||
tokens=tokens
|
tokens=tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
@ -179,7 +212,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
|||||||
price_unit=input_price_info.unit,
|
price_unit=input_price_info.unit,
|
||||||
total_price=input_price_info.total_amount,
|
total_price=input_price_info.total_amount,
|
||||||
currency=input_price_info.currency,
|
currency=input_price_info.currency,
|
||||||
latency=time.perf_counter() - self.started_at
|
latency=time.perf_counter() - self.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|||||||
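The embedding hunks above request encoding_format="base64" and decode each returned vector with numpy. A self-contained round-trip showing that decode step:

import base64

import numpy as np


def decode_base64_embedding(data: str) -> list:
    """Decode a base64-encoded float32 vector, as the hunk above does per embedding."""
    return list(np.frombuffer(base64.b64decode(data), dtype="float32"))


vector = np.array([0.1, -0.2, 0.3], dtype="float32")
encoded = base64.b64encode(vector.tobytes()).decode()
print(decode_base64_embedding(encoded))  # approximately [0.1, -0.2, 0.3]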
@ -3,16 +3,24 @@ import copy
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
|
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
|
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
from model_providers.core.model_runtime.model_providers.azure_openai._common import (
|
||||||
from model_providers.core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
|
_CommonAzureOpenAI,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.azure_openai._constant import (
|
||||||
|
TTS_BASE_MODELS,
|
||||||
|
AzureBaseModel,
|
||||||
|
)
|
||||||
from model_providers.extensions.ext_storage import storage
|
from model_providers.extensions.ext_storage import storage
|
||||||
|
|
||||||
|
|
||||||
@ -21,8 +29,16 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
Model class for Azure OpenAI text-to-speech model.
|
Model class for Azure OpenAI text-to-speech model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(self, model: str, tenant_id: str, credentials: dict,
|
def _invoke(
|
||||||
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
|
self,
|
||||||
|
model: str,
|
||||||
|
tenant_id: str,
|
||||||
|
credentials: dict,
|
||||||
|
content_text: str,
|
||||||
|
voice: str,
|
||||||
|
streaming: bool,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> any:
|
||||||
"""
|
"""
|
||||||
_invoke text2speech model
|
_invoke text2speech model
|
||||||
|
|
||||||
@ -36,20 +52,34 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
:return: text translated to audio file
|
:return: text translated to audio file
|
||||||
"""
|
"""
|
||||||
audio_type = self._get_model_audio_type(model, credentials)
|
audio_type = self._get_model_audio_type(model, credentials)
|
||||||
if not voice or voice not in [d['value'] for d in
|
if not voice or voice not in [
|
||||||
self.get_tts_model_voices(model=model, credentials=credentials)]:
|
d["value"]
|
||||||
|
for d in self.get_tts_model_voices(model=model, credentials=credentials)
|
||||||
|
]:
|
||||||
voice = self._get_model_default_voice(model, credentials)
|
voice = self._get_model_default_voice(model, credentials)
|
||||||
if streaming:
|
if streaming:
|
||||||
return StreamingResponse(self._tts_invoke_streaming(model=model,
|
return StreamingResponse(
|
||||||
credentials=credentials,
|
self._tts_invoke_streaming(
|
||||||
content_text=content_text,
|
model=model,
|
||||||
tenant_id=tenant_id,
|
credentials=credentials,
|
||||||
voice=voice), media_type='text/event-stream')
|
content_text=content_text,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
voice=voice,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
|
return self._tts_invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
content_text=content_text,
|
||||||
|
voice=voice,
|
||||||
|
)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
def validate_credentials(
|
||||||
|
self, model: str, credentials: dict, user: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
validate credentials text2speech model
|
validate credentials text2speech model
|
||||||
|
|
||||||
@ -62,13 +92,15 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
self._tts_invoke(
|
self._tts_invoke(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
content_text='Hello Dify!',
|
content_text="Hello Dify!",
|
||||||
voice=self._get_model_default_voice(model, credentials),
|
voice=self._get_model_default_voice(model, credentials),
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
|
def _tts_invoke(
|
||||||
|
self, model: str, credentials: dict, content_text: str, voice: str
|
||||||
|
) -> StreamingResponse:
|
||||||
"""
|
"""
|
||||||
_tts_invoke text2speech model
|
_tts_invoke text2speech model
|
||||||
|
|
||||||
@ -82,13 +114,25 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
word_limit = self._get_model_word_limit(model, credentials)
|
word_limit = self._get_model_word_limit(model, credentials)
|
||||||
max_workers = self._get_model_workers_limit(model, credentials)
|
max_workers = self._get_model_workers_limit(model, credentials)
|
||||||
try:
|
try:
|
||||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
sentences = list(
|
||||||
|
self._split_text_into_sentences(text=content_text, limit=word_limit)
|
||||||
|
)
|
||||||
audio_bytes_list = list()
|
audio_bytes_list = list()
|
||||||
|
|
||||||
# Create a thread pool and map the function to the list of sentences
|
# Create a thread pool and map the function to the list of sentences
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(
|
||||||
futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
|
max_workers=max_workers
|
||||||
credentials=credentials) for sentence in sentences]
|
) as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
self._process_sentence,
|
||||||
|
sentence=sentence,
|
||||||
|
model=model,
|
||||||
|
voice=voice,
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
for sentence in sentences
|
||||||
|
]
|
||||||
for future in futures:
|
for future in futures:
|
||||||
try:
|
try:
|
||||||
if future.result():
|
if future.result():
|
||||||
@ -97,8 +141,11 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
if len(audio_bytes_list) > 0:
|
if len(audio_bytes_list) > 0:
|
||||||
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
|
audio_segments = [
|
||||||
audio_bytes_list if audio_bytes]
|
AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
|
||||||
|
for audio_bytes in audio_bytes_list
|
||||||
|
if audio_bytes
|
||||||
|
]
|
||||||
combined_segment = reduce(lambda x, y: x + y, audio_segments)
|
combined_segment = reduce(lambda x, y: x + y, audio_segments)
|
||||||
buffer: BytesIO = BytesIO()
|
buffer: BytesIO = BytesIO()
|
||||||
combined_segment.export(buffer, format=audio_type)
|
combined_segment.export(buffer, format=audio_type)
|
||||||
@ -108,8 +155,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
# TODO: improve the streaming implementation
|
# TODO: improve the streaming implementation
|
||||||
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
|
def _tts_invoke_streaming(
|
||||||
voice: str) -> any:
|
self,
|
||||||
|
model: str,
|
||||||
|
tenant_id: str,
|
||||||
|
credentials: dict,
|
||||||
|
content_text: str,
|
||||||
|
voice: str,
|
||||||
|
) -> any:
|
||||||
"""
|
"""
|
||||||
_tts_invoke_streaming text2speech model
|
_tts_invoke_streaming text2speech model
|
||||||
|
|
||||||
@ -122,24 +175,29 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
"""
|
"""
|
||||||
# transform credentials to kwargs for model instance
|
# transform credentials to kwargs for model instance
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
|
if not voice or voice not in self.get_tts_model_voices(
|
||||||
|
model=model, credentials=credentials
|
||||||
|
):
|
||||||
voice = self._get_model_default_voice(model, credentials)
|
voice = self._get_model_default_voice(model, credentials)
|
||||||
word_limit = self._get_model_word_limit(model, credentials)
|
word_limit = self._get_model_word_limit(model, credentials)
|
||||||
audio_type = self._get_model_audio_type(model, credentials)
|
audio_type = self._get_model_audio_type(model, credentials)
|
||||||
tts_file_id = self._get_file_name(content_text)
|
tts_file_id = self._get_file_name(content_text)
|
||||||
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
|
file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}"
|
||||||
try:
|
try:
|
||||||
client = AzureOpenAI(**credentials_kwargs)
|
client = AzureOpenAI(**credentials_kwargs)
|
||||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
sentences = list(
|
||||||
|
self._split_text_into_sentences(text=content_text, limit=word_limit)
|
||||||
|
)
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
|
response = client.audio.speech.create(
|
||||||
|
model=model, voice=voice, input=sentence.strip()
|
||||||
|
)
|
||||||
# response.stream_to_file(file_path)
|
# response.stream_to_file(file_path)
|
||||||
storage.save(file_path, response.read())
|
storage.save(file_path, response.read())
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise InvokeBadRequestError(str(ex))
|
raise InvokeBadRequestError(str(ex))
|
||||||
|
|
||||||
def _process_sentence(self, sentence: str, model: str,
|
def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
|
||||||
voice, credentials: dict):
|
|
||||||
"""
|
"""
|
||||||
_tts_invoke openai text2speech model api
|
_tts_invoke openai text2speech model api
|
||||||
|
|
||||||
@ -152,12 +210,18 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
|
|||||||
# transform credentials to kwargs for model instance
|
# transform credentials to kwargs for model instance
|
||||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||||
client = AzureOpenAI(**credentials_kwargs)
|
client = AzureOpenAI(**credentials_kwargs)
|
||||||
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
|
response = client.audio.speech.create(
|
||||||
|
model=model, voice=voice, input=sentence.strip()
|
||||||
|
)
|
||||||
if isinstance(response.read(), bytes):
|
if isinstance(response.read(), bytes):
|
||||||
return response.read()
|
return response.read()
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
def get_customizable_model_schema(
|
||||||
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
|
self, model: str, credentials: dict
|
||||||
|
) -> Optional[AIModelEntity]:
|
||||||
|
ai_model_entity = self._get_ai_model_entity(
|
||||||
|
credentials["base_model_name"], model
|
||||||
|
)
|
||||||
return ai_model_entity.entity
|
return ai_model_entity.entity
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
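The TTS hunks above split long text into sentences and synthesize them concurrently in a thread pool before joining the audio (with pydub AudioSegment in the real code). A dependency-free sketch of that fan-out/fan-in pattern, with a dummy synthesize() standing in for the Azure speech call and a simplified sentence splitter:

import concurrent.futures


def synthesize(sentence: str) -> bytes:
    # dummy stand-in for the per-sentence TTS request in _process_sentence
    return f"<audio:{sentence.strip()}>".encode()


def tts_by_sentence(text: str, max_workers: int = 3) -> bytes:
    """Split text, synthesize each sentence in a thread pool, then join pieces in order."""
    sentences = [s for s in text.split("。") if s.strip()]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(synthesize, sentence) for sentence in sentences]
        # futures were submitted in sentence order, so collecting results in the
        # same order keeps the audio ordered even though synthesis runs concurrently
        return b"".join(future.result() for future in futures)


print(tts_by_sentence("你好。今天天气不错。"))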
@ -1,11 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaichuanProvider(ModelProvider):
|
class BaichuanProvider(ModelProvider):
|
||||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@ -20,11 +25,12 @@ class BaichuanProvider(ModelProvider):
|
|||||||
|
|
||||||
# Use `baichuan2-turbo` model for validate,
|
# Use `baichuan2-turbo` model for validate,
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='baichuan2-turbo',
|
model="baichuan2-turbo", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
@ -4,17 +4,20 @@ import re
|
|||||||
class BaichuanTokenizer:
|
class BaichuanTokenizer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def count_chinese_characters(cls, text: str) -> int:
|
def count_chinese_characters(cls, text: str) -> int:
|
||||||
return len(re.findall(r'[\u4e00-\u9fa5]', text))
|
return len(re.findall(r"[\u4e00-\u9fa5]", text))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def count_english_vocabularies(cls, text: str) -> int:
|
def count_english_vocabularies(cls, text: str) -> int:
|
||||||
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
|
# remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc.
|
||||||
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
|
||||||
# count the number of words not characters
|
# count the number of words not characters
|
||||||
return len(text.split())
|
return len(text.split())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_num_tokens(cls, text: str) -> int:
|
def _get_num_tokens(cls, text: str) -> int:
|
||||||
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
|
# tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
|
||||||
# https://platform.baichuan-ai.com/docs/text-Embedding
|
# https://platform.baichuan-ai.com/docs/text-Embedding
|
||||||
return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
|
return int(
|
||||||
|
cls.count_chinese_characters(text)
|
||||||
|
+ cls.count_english_vocabularies(text) * 1.3
|
||||||
|
)
|
||||||
|
|||||||
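BaichuanTokenizer above estimates tokens as the number of Chinese characters plus 1.3 tokens per English word. The same heuristic as a standalone function, with a quick check:

import re


def estimate_baichuan_tokens(text: str) -> int:
    """Estimate tokens as (Chinese characters) + (English words) * 1.3."""
    chinese_characters = len(re.findall(r"[\u4e00-\u9fa5]", text))
    english_only = re.sub(r"[^a-zA-Z0-9\s]", "", text)
    english_words = len(english_only.split())
    return int(chinese_characters + english_words * 1.3)


print(estimate_baichuan_tokens("你好 world, how are you"))  # 2 + 4 * 1.3 -> 7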
@ -18,153 +18,188 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu
|
|||||||
|
|
||||||
class BaichuanMessage:
|
class BaichuanMessage:
|
||||||
class Role(Enum):
|
class Role(Enum):
|
||||||
USER = 'user'
|
USER = "user"
|
||||||
ASSISTANT = 'assistant'
|
ASSISTANT = "assistant"
|
||||||
# Baichuan does not have system message
|
# Baichuan does not have system message
|
||||||
_SYSTEM = 'system'
|
_SYSTEM = "system"
|
||||||
|
|
||||||
role: str = Role.USER.value
|
role: str = Role.USER.value
|
||||||
content: str
|
content: str
|
||||||
usage: dict[str, int] = None
|
usage: dict[str, int] = None
|
||||||
stop_reason: str = ''
|
stop_reason: str = ""
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
'role': self.role,
|
"role": self.role,
|
||||||
'content': self.content,
|
"content": self.content,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, content: str, role: str = 'user') -> None:
|
def __init__(self, content: str, role: str = "user") -> None:
|
||||||
self.content = content
|
self.content = content
|
||||||
self.role = role
|
self.role = role
|
||||||
|
|
||||||
|
|
||||||
class BaichuanModel:
|
class BaichuanModel:
|
||||||
api_key: str
|
api_key: str
|
||||||
secret_key: str
|
secret_key: str
|
||||||
|
|
||||||
def __init__(self, api_key: str, secret_key: str = '') -> None:
|
def __init__(self, api_key: str, secret_key: str = "") -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.secret_key = secret_key
|
self.secret_key = secret_key
|
||||||
|
|
||||||
def _model_mapping(self, model: str) -> str:
|
def _model_mapping(self, model: str) -> str:
|
||||||
return {
|
return {
|
||||||
'baichuan2-turbo': 'Baichuan2-Turbo',
|
"baichuan2-turbo": "Baichuan2-Turbo",
|
||||||
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
|
"baichuan2-turbo-192k": "Baichuan2-Turbo-192k",
|
||||||
'baichuan2-53b': 'Baichuan2-53B',
|
"baichuan2-53b": "Baichuan2-53B",
|
||||||
}[model]
|
}[model]
|
||||||
|
|
||||||
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
|
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
|
||||||
resp = response.json()
|
resp = response.json()
|
||||||
choices = resp.get('choices', [])
|
choices = resp.get("choices", [])
|
||||||
message = BaichuanMessage(content='', role='assistant')
|
message = BaichuanMessage(content="", role="assistant")
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
message.content += choice['message']['content']
|
message.content += choice["message"]["content"]
|
||||||
message.role = choice['message']['role']
|
message.role = choice["message"]["role"]
|
||||||
if choice['finish_reason']:
|
if choice["finish_reason"]:
|
||||||
message.stop_reason = choice['finish_reason']
|
message.stop_reason = choice["finish_reason"]
|
||||||
|
|
||||||
|
if "usage" in resp:
|
||||||
|
message.usage = {
|
||||||
|
"prompt_tokens": resp["usage"]["prompt_tokens"],
|
||||||
|
"completion_tokens": resp["usage"]["completion_tokens"],
|
||||||
|
"total_tokens": resp["usage"]["total_tokens"],
|
||||||
|
}
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
if 'usage' in resp:
|
|
||||||
message.usage = {
|
|
||||||
'prompt_tokens': resp['usage']['prompt_tokens'],
|
|
||||||
'completion_tokens': resp['usage']['completion_tokens'],
|
|
||||||
'total_tokens': resp['usage']['total_tokens'],
|
|
||||||
}
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
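For orientation, the non-streaming handler above reduces to roughly the following shape (a simplified sketch using a plain dict instead of `BaichuanMessage`).

```python
# Simplified sketch of the non-streaming parse: concatenate choice contents,
# keep the last finish_reason, and copy token usage when present.
def parse_chat_response(resp: dict) -> dict:
    message = {"role": "assistant", "content": "", "stop_reason": "", "usage": None}
    for choice in resp.get("choices", []):
        message["content"] += choice["message"]["content"]
        message["role"] = choice["message"]["role"]
        if choice.get("finish_reason"):
            message["stop_reason"] = choice["finish_reason"]
    if "usage" in resp:
        message["usage"] = {
            key: resp["usage"][key]
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }
    return message
```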
def _handle_chat_stream_generate_response(self, response) -> Generator:
|
def _handle_chat_stream_generate_response(self, response) -> Generator:
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
line = line.decode('utf-8')
|
line = line.decode("utf-8")
|
||||||
# remove the first `data: ` prefix
|
# remove the first `data: ` prefix
|
||||||
if line.startswith('data:'):
|
if line.startswith("data:"):
|
||||||
line = line[5:].strip()
|
line = line[5:].strip()
|
||||||
try:
|
try:
|
||||||
data = loads(line)
|
data = loads(line)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if line.strip() == '[DONE]':
|
if line.strip() == "[DONE]":
|
||||||
return
|
return
|
||||||
choices = data.get('choices', [])
|
choices = data.get("choices", [])
|
||||||
# save stop reason temporarily
|
# save stop reason temporarily
|
||||||
stop_reason = ''
|
stop_reason = ""
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
if 'finish_reason' in choice and choice['finish_reason']:
|
if "finish_reason" in choice and choice["finish_reason"]:
|
||||||
stop_reason = choice['finish_reason']
|
stop_reason = choice["finish_reason"]
|
||||||
|
|
||||||
if len(choice['delta']['content']) == 0:
|
if len(choice["delta"]["content"]) == 0:
|
||||||
continue
|
continue
|
||||||
yield BaichuanMessage(**choice['delta'])
|
yield BaichuanMessage(**choice["delta"])
|
||||||
|
|
||||||
# if there is usage, the response is the last one, yield it and return
|
# if there is usage, the response is the last one, yield it and return
|
||||||
if 'usage' in data:
|
if "usage" in data:
|
||||||
message = BaichuanMessage(content='', role='assistant')
|
message = BaichuanMessage(content="", role="assistant")
|
||||||
message.usage = {
|
message.usage = {
|
||||||
'prompt_tokens': data['usage']['prompt_tokens'],
|
"prompt_tokens": data["usage"]["prompt_tokens"],
|
||||||
'completion_tokens': data['usage']['completion_tokens'],
|
"completion_tokens": data["usage"]["completion_tokens"],
|
||||||
'total_tokens': data['usage']['total_tokens'],
|
"total_tokens": data["usage"]["total_tokens"],
|
||||||
}
|
}
|
||||||
message.stop_reason = stop_reason
|
message.stop_reason = stop_reason
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
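The streaming handler above follows the usual server-sent-events pattern. Below is a condensed sketch of just the line parsing; the real method additionally yields `BaichuanMessage` deltas and a final message carrying usage and the stop reason.

```python
# Sketch of the SSE loop: each payload line carries a "data:" prefix,
# the literal "[DONE]" terminates the stream, and everything else is JSON.
from json import loads


def iter_sse_payloads(lines):
    for raw in lines:
        if not raw:
            continue
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if line.startswith("data:"):
            line = line[5:].strip()
        if line.strip() == "[DONE]":
            return
        yield loads(line)
```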
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
|
def _build_parameters(
|
||||||
parameters: dict[str, Any]) \
|
self,
|
||||||
-> dict[str, Any]:
|
model: str,
|
||||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
stream: bool,
|
||||||
|
messages: list[BaichuanMessage],
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if (
|
||||||
|
model == "baichuan2-turbo"
|
||||||
|
or model == "baichuan2-turbo-192k"
|
||||||
|
or model == "baichuan2-53b"
|
||||||
|
):
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
|
if (
|
||||||
|
message.role == BaichuanMessage.Role.USER.value
|
||||||
|
or message.role == BaichuanMessage.Role._SYSTEM.value
|
||||||
|
):
|
||||||
# check if the latest message is a user message
|
# check if the latest message is a user message
|
||||||
if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
|
if (
|
||||||
prompt_messages[-1]['content'] += message.content
|
len(prompt_messages) > 0
|
||||||
|
and prompt_messages[-1]["role"]
|
||||||
|
== BaichuanMessage.Role.USER.value
|
||||||
|
):
|
||||||
|
prompt_messages[-1]["content"] += message.content
|
||||||
else:
|
else:
|
||||||
prompt_messages.append({
|
prompt_messages.append(
|
||||||
'content': message.content,
|
{
|
||||||
'role': BaichuanMessage.Role.USER.value,
|
"content": message.content,
|
||||||
})
|
"role": BaichuanMessage.Role.USER.value,
|
||||||
|
}
|
||||||
|
)
|
||||||
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
|
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
|
||||||
prompt_messages.append({
|
prompt_messages.append(
|
||||||
'content': message.content,
|
{
|
||||||
'role': message.role,
|
"content": message.content,
|
||||||
})
|
"role": message.role,
|
||||||
|
}
|
||||||
|
)
|
||||||
# [baichuan] frequency_penalty must be between 1 and 2
|
# [baichuan] frequency_penalty must be between 1 and 2
|
||||||
if 'frequency_penalty' in parameters:
|
if "frequency_penalty" in parameters:
|
||||||
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
|
if (
|
||||||
parameters['frequency_penalty'] = 1
|
parameters["frequency_penalty"] < 1
|
||||||
|
or parameters["frequency_penalty"] > 2
|
||||||
|
):
|
||||||
|
parameters["frequency_penalty"] = 1
|
||||||
|
|
||||||
# turbo api accepts flat parameters
|
# turbo api accepts flat parameters
|
||||||
return {
|
return {
|
||||||
'model': self._model_mapping(model),
|
"model": self._model_mapping(model),
|
||||||
'stream': stream,
|
"stream": stream,
|
||||||
'messages': prompt_messages,
|
"messages": prompt_messages,
|
||||||
**parameters,
|
**parameters,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise BadRequestError(f"Unknown model: {model}")
|
raise BadRequestError(f"Unknown model: {model}")
|
||||||
|
|
||||||
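`_build_parameters` above essentially flattens the conversation and clamps `frequency_penalty` into Baichuan's accepted range. A condensed sketch follows, with plain dicts standing in for `BaichuanMessage` and the `_model_mapping` name translation omitted.

```python
# Condensed sketch of the turbo request body: consecutive user/system turns are
# merged into one user message; frequency_penalty is clamped into [1, 2].
from typing import Any


def build_turbo_body(
    model: str, stream: bool, messages: list[dict], parameters: dict
) -> dict[str, Any]:
    prompt_messages: list[dict] = []
    for message in messages:
        if message["role"] in ("user", "system"):
            if prompt_messages and prompt_messages[-1]["role"] == "user":
                prompt_messages[-1]["content"] += message["content"]
            else:
                prompt_messages.append({"role": "user", "content": message["content"]})
        elif message["role"] == "assistant":
            prompt_messages.append({"role": "assistant", "content": message["content"]})
    if "frequency_penalty" in parameters and not 1 <= parameters["frequency_penalty"] <= 2:
        parameters["frequency_penalty"] = 1
    return {"model": model, "stream": stream, "messages": prompt_messages, **parameters}
```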
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
|
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
if (
|
||||||
|
model == "baichuan2-turbo"
|
||||||
|
or model == "baichuan2-turbo-192k"
|
||||||
|
or model == "baichuan2-53b"
|
||||||
|
):
|
||||||
# there is no secret key for turbo api
|
# there is no secret key for turbo api
|
||||||
return {
|
return {
|
||||||
'Content-Type': 'application/json',
|
"Content-Type": "application/json",
|
||||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
|
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ",
|
||||||
'Authorization': 'Bearer ' + self.api_key,
|
"Authorization": "Bearer " + self.api_key,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise BadRequestError(f"Unknown model: {model}")
|
raise BadRequestError(f"Unknown model: {model}")
|
||||||
|
|
||||||
def _calculate_md5(self, input_string):
|
|
||||||
return md5(input_string.encode('utf-8')).hexdigest()
|
|
||||||
|
|
||||||
def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
|
def _calculate_md5(self, input_string):
|
||||||
parameters: dict[str, Any], timeout: int) \
|
return md5(input_string.encode("utf-8")).hexdigest()
|
||||||
-> Union[Generator, BaichuanMessage]:
|
|
||||||
|
def generate(
|
||||||
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
|
self,
|
||||||
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
|
model: str,
|
||||||
|
stream: bool,
|
||||||
|
messages: list[BaichuanMessage],
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
timeout: int,
|
||||||
|
) -> Union[Generator, BaichuanMessage]:
|
||||||
|
if (
|
||||||
|
model == "baichuan2-turbo"
|
||||||
|
or model == "baichuan2-turbo-192k"
|
||||||
|
or model == "baichuan2-53b"
|
||||||
|
):
|
||||||
|
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||||
else:
|
else:
|
||||||
raise BadRequestError(f"Unknown model: {model}")
|
raise BadRequestError(f"Unknown model: {model}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = self._build_parameters(model, stream, messages, parameters)
|
data = self._build_parameters(model, stream, messages, parameters)
|
||||||
headers = self._build_headers(model, data)
|
headers = self._build_headers(model, data)
|
||||||
@ -177,35 +212,37 @@ class BaichuanModel:
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
data=dumps(data),
|
data=dumps(data),
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
stream=stream
|
stream=stream,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InternalServerError(f"Failed to invoke model: {e}")
|
raise InternalServerError(f"Failed to invoke model: {e}")
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
try:
|
try:
|
||||||
resp = response.json()
|
resp = response.json()
|
||||||
# try to parse error message
|
# try to parse error message
|
||||||
err = resp['error']['code']
|
err = resp["error"]["code"]
|
||||||
msg = resp['error']['message']
|
msg = resp["error"]["message"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
raise InternalServerError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
if err == 'invalid_api_key':
|
if err == "invalid_api_key":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
elif err == 'insufficient_quota':
|
elif err == "insufficient_quota":
|
||||||
raise InsufficientAccountBalance(msg)
|
raise InsufficientAccountBalance(msg)
|
||||||
elif err == 'invalid_authentication':
|
elif err == "invalid_authentication":
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
elif 'rate' in err:
|
elif "rate" in err:
|
||||||
raise RateLimitReachedError(msg)
|
raise RateLimitReachedError(msg)
|
||||||
elif 'internal' in err:
|
elif "internal" in err:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
elif err == 'api_key_empty':
|
elif err == "api_key_empty":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
else:
|
else:
|
||||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_chat_stream_generate_response(response)
|
return self._handle_chat_stream_generate_response(response)
|
||||||
else:
|
else:
|
||||||
|
|||||||
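A hypothetical non-streaming call against the class above might look as follows; the API key and sampling parameters are placeholders, and per the signature the non-streaming path returns a single `BaichuanMessage`.

```python
# Hypothetical usage of BaichuanModel (placeholder key and parameters).
model = BaichuanModel(api_key="sk-...")
reply = model.generate(
    model="baichuan2-turbo",
    stream=False,
    messages=[BaichuanMessage(content="你好", role="user")],
    parameters={"temperature": 0.3, "max_tokens": 256},
    timeout=60,
)
print(reply.content, reply.usage)
```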
@ -1,17 +1,22 @@
|
|||||||
class InvalidAuthenticationError(Exception):
|
class InvalidAuthenticationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidAPIKeyError(Exception):
|
class InvalidAPIKeyError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RateLimitReachedError(Exception):
|
class RateLimitReachedError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InsufficientAccountBalance(Exception):
|
class InsufficientAccountBalance(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InternalServerError(Exception):
|
class InternalServerError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BadRequestError(Exception):
|
class BadRequestError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -17,10 +21,19 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
CredentialsValidateFailedError,
|
||||||
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
|
LargeLanguageModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import (
|
||||||
|
BaichuanTokenizer,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import (
|
||||||
|
BaichuanMessage,
|
||||||
|
BaichuanModel,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalance,
|
||||||
@ -32,20 +45,43 @@ from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tu
|
|||||||
|
|
||||||
|
|
||||||
class BaichuanLarguageModel(LargeLanguageModel):
|
class BaichuanLarguageModel(LargeLanguageModel):
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
model: str,
|
||||||
stream: bool = True, user: str | None = None) \
|
credentials: dict,
|
||||||
-> LLMResult | Generator:
|
prompt_messages: list[PromptMessage],
|
||||||
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
model_parameters: dict,
|
||||||
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> LLMResult | Generator:
|
||||||
|
return self._generate(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
|
model_parameters=model_parameters,
|
||||||
|
tools=tools,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(
|
||||||
tools: list[PromptMessageTool] | None = None) -> int:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
) -> int:
|
||||||
return self._num_tokens_from_messages(prompt_messages)
|
return self._num_tokens_from_messages(prompt_messages)
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
|
def _num_tokens_from_messages(
|
||||||
|
self,
|
||||||
|
messages: list[PromptMessage],
|
||||||
|
) -> int:
|
||||||
"""Calculate num tokens for baichuan model"""
|
"""Calculate num tokens for baichuan model"""
|
||||||
|
|
||||||
def tokens(text: str):
|
def tokens(text: str):
|
||||||
return BaichuanTokenizer._get_num_tokens(text)
|
return BaichuanTokenizer._get_num_tokens(text)
|
||||||
|
|
||||||
@ -57,10 +93,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
text = ''
|
text = ""
|
||||||
for item in value:
|
for item in value:
|
||||||
if isinstance(item, dict) and item['type'] == 'text':
|
if isinstance(item, dict) and item["type"] == "text":
|
||||||
text += item['text']
|
text += item["text"]
|
||||||
|
|
||||||
value = text
|
value = text
|
||||||
|
|
||||||
@ -87,89 +123,123 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||||||
message_dict = {"role": "user", "content": message.content}
|
message_dict = {"role": "user", "content": message.content}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown message type {type(message)}")
|
raise ValueError(f"Unknown message type {type(message)}")
|
||||||
|
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
# ping
|
# ping
|
||||||
instance = BaichuanModel(
|
instance = BaichuanModel(
|
||||||
api_key=credentials['api_key'],
|
api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "")
|
||||||
secret_key=credentials.get('secret_key', '')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
instance.generate(model=model, stream=False, messages=[
|
instance.generate(
|
||||||
BaichuanMessage(content='ping', role='user')
|
model=model,
|
||||||
], parameters={
|
stream=False,
|
||||||
'max_tokens': 1,
|
messages=[BaichuanMessage(content="ping", role="user")],
|
||||||
}, timeout=60)
|
parameters={
|
||||||
|
"max_tokens": 1,
|
||||||
|
},
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def _generate(
|
||||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
self,
|
||||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
model: str,
|
||||||
-> LLMResult | Generator:
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> LLMResult | Generator:
|
||||||
if tools is not None and len(tools) > 0:
|
if tools is not None and len(tools) > 0:
|
||||||
raise InvokeBadRequestError("Baichuan model doesn't support tools")
|
raise InvokeBadRequestError("Baichuan model doesn't support tools")
|
||||||
|
|
||||||
instance = BaichuanModel(
|
instance = BaichuanModel(
|
||||||
api_key=credentials['api_key'],
|
api_key=credentials["api_key"], secret_key=credentials.get("secret_key", "")
|
||||||
secret_key=credentials.get('secret_key', '')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert prompt messages to baichuan messages
|
# convert prompt messages to baichuan messages
|
||||||
messages = [
|
messages = [
|
||||||
BaichuanMessage(
|
BaichuanMessage(
|
||||||
content=message.content if isinstance(message.content, str) else ''.join([
|
content=message.content
|
||||||
content.data for content in message.content
|
if isinstance(message.content, str)
|
||||||
]),
|
else "".join([content.data for content in message.content]),
|
||||||
role=message.role.value
|
role=message.role.value,
|
||||||
) for message in prompt_messages
|
)
|
||||||
|
for message in prompt_messages
|
||||||
]
|
]
|
||||||
|
|
||||||
# invoke model
|
# invoke model
|
||||||
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
|
response = instance.generate(
|
||||||
|
model=model,
|
||||||
|
stream=stream,
|
||||||
|
messages=messages,
|
||||||
|
parameters=model_parameters,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
|
return self._handle_chat_generate_stream_response(
|
||||||
|
model, prompt_messages, credentials, response
|
||||||
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
|
)
|
||||||
|
|
||||||
def _handle_chat_generate_response(self, model: str,
|
return self._handle_chat_generate_response(
|
||||||
prompt_messages: list[PromptMessage],
|
model, prompt_messages, credentials, response
|
||||||
credentials: dict,
|
)
|
||||||
response: BaichuanMessage) -> LLMResult:
|
|
||||||
|
def _handle_chat_generate_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
credentials: dict,
|
||||||
|
response: BaichuanMessage,
|
||||||
|
) -> LLMResult:
|
||||||
# convert baichuan message to llm result
|
# convert baichuan message to llm result
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=response.usage["prompt_tokens"],
|
||||||
|
completion_tokens=response.usage["completion_tokens"],
|
||||||
|
)
|
||||||
return LLMResult(
|
return LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(content=response.content, tool_calls=[]),
|
||||||
content=response.content,
|
|
||||||
tool_calls=[]
|
|
||||||
),
|
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(self, model: str,
|
def _handle_chat_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
credentials: dict,
|
model: str,
|
||||||
response: Generator[BaichuanMessage, None, None]) -> Generator:
|
prompt_messages: list[PromptMessage],
|
||||||
|
credentials: dict,
|
||||||
|
response: Generator[BaichuanMessage, None, None],
|
||||||
|
) -> Generator:
|
||||||
for message in response:
|
for message in response:
|
||||||
if message.usage:
|
if message.usage:
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=message.usage["prompt_tokens"],
|
||||||
|
completion_tokens=message.usage["completion_tokens"],
|
||||||
|
)
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(
|
||||||
content=message.content,
|
content=message.content, tool_calls=[]
|
||||||
tool_calls=[]
|
|
||||||
),
|
),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason
|
||||||
|
if message.stop_reason
|
||||||
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -179,10 +249,11 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(
|
message=AssistantPromptMessage(
|
||||||
content=message.content,
|
content=message.content, tool_calls=[]
|
||||||
tool_calls=[]
|
|
||||||
),
|
),
|
||||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
finish_reason=message.stop_reason
|
||||||
|
if message.stop_reason
|
||||||
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -197,21 +268,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [],
|
||||||
],
|
InvokeServerUnavailableError: [InternalServerError],
|
||||||
InvokeServerUnavailableError: [
|
InvokeRateLimitError: [RateLimitReachedError],
|
||||||
InternalServerError
|
|
||||||
],
|
|
||||||
InvokeRateLimitError: [
|
|
||||||
RateLimitReachedError
|
|
||||||
],
|
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalance,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||||
BadRequestError,
|
|
||||||
KeyError
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|||||||
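The `_invoke_error_mapping` property above hands the runtime a table from unified invoke errors to provider-specific exceptions. One plausible way such a table is consumed is sketched below; this is illustrative only and not necessarily the exact base-class logic, and the import path is assumed to match the surrounding module.

```python
# Sketch: re-raise a provider exception as the matching unified InvokeError subclass.
from model_providers.core.model_runtime.errors.invoke import InvokeError


def to_invoke_error(error_mapping: dict, exc: Exception) -> InvokeError:
    for invoke_error, provider_errors in error_mapping.items():
        if any(isinstance(exc, provider_error) for provider_error in provider_errors):
            return invoke_error(str(exc))
    return InvokeError(str(exc))  # fallback for unmapped exceptions
```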
@ -5,7 +5,10 @@ from typing import Optional
|
|||||||
from requests import post
|
from requests import post
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import PriceType
|
from model_providers.core.model_runtime.entities.model_entities import PriceType
|
||||||
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from model_providers.core.model_runtime.entities.text_embedding_entities import (
|
||||||
|
EmbeddingUsage,
|
||||||
|
TextEmbeddingResult,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.errors.invoke import (
|
from model_providers.core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
InvokeBadRequestError,
|
InvokeBadRequestError,
|
||||||
@ -14,9 +17,15 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
CredentialsValidateFailedError,
|
||||||
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||||
|
TextEmbeddingModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import (
|
||||||
|
BaichuanTokenizer,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
from model_providers.core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalance,
|
||||||
@ -31,11 +40,16 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
"""
|
"""
|
||||||
Model class for BaiChuan text embedding model.
|
Model class for BaiChuan text embedding model.
|
||||||
"""
|
"""
|
||||||
api_base: str = 'http://api.baichuan-ai.com/v1/embeddings'
|
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
|
||||||
texts: list[str], user: Optional[str] = None) \
|
|
||||||
-> TextEmbeddingResult:
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
texts: list[str],
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke text embedding model
|
Invoke text embedding model
|
||||||
|
|
||||||
@ -45,27 +59,24 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:return: embeddings result
|
:return: embeddings result
|
||||||
"""
|
"""
|
||||||
api_key = credentials['api_key']
|
api_key = credentials["api_key"]
|
||||||
if model != 'baichuan-text-embedding':
|
if model != "baichuan-text-embedding":
|
||||||
raise ValueError('Invalid model name')
|
raise ValueError("Invalid model name")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise CredentialsValidateFailedError('api_key is required')
|
raise CredentialsValidateFailedError("api_key is required")
|
||||||
|
|
||||||
# split into chunks of batch size 16
|
# split into chunks of batch size 16
|
||||||
chunks = []
|
chunks = []
|
||||||
for i in range(0, len(texts), 16):
|
for i in range(0, len(texts), 16):
|
||||||
chunks.append(texts[i:i + 16])
|
chunks.append(texts[i : i + 16])
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
token_usage = 0
|
token_usage = 0
|
||||||
|
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
# embeding chunk
|
# embedding chunk
|
||||||
chunk_embeddings, chunk_usage = self.embedding(
|
chunk_embeddings, chunk_usage = self.embedding(
|
||||||
model=model,
|
model=model, api_key=api_key, texts=chunk, user=user
|
||||||
api_key=api_key,
|
|
||||||
texts=chunk,
|
|
||||||
user=user
|
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings.extend(chunk_embeddings)
|
embeddings.extend(chunk_embeddings)
|
||||||
@ -75,16 +86,15 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
model=model,
|
model=model,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
usage=self._calc_response_usage(
|
usage=self._calc_response_usage(
|
||||||
model=model,
|
model=model, credentials=credentials, tokens=token_usage
|
||||||
credentials=credentials,
|
),
|
||||||
tokens=token_usage
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
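The chunking in `_invoke` above keeps each request under a 16-text batch size and sums token usage across chunks. The pattern in isolation, as a sketch with `embed_chunk` standing in for `self.embedding`:

```python
# Sketch of the 16-texts-per-request batching used above.
def embed_in_batches(texts: list[str], embed_chunk, batch_size: int = 16):
    embeddings: list[list[float]] = []
    total_tokens = 0
    for i in range(0, len(texts), batch_size):
        chunk_embeddings, chunk_tokens = embed_chunk(texts[i : i + batch_size])
        embeddings.extend(chunk_embeddings)
        total_tokens += chunk_tokens
    return embeddings, total_tokens
```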
def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \
|
def embedding(
|
||||||
-> tuple[list[list[float]], int]:
|
self, model: str, api_key, texts: list[str], user: Optional[str] = None
|
||||||
|
) -> tuple[list[list[float]], int]:
|
||||||
"""
|
"""
|
||||||
Embed given texts
|
Embed given texts
|
||||||
|
|
||||||
@ -96,55 +106,53 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
"""
|
"""
|
||||||
url = self.api_base
|
url = self.api_base
|
||||||
headers = {
|
headers = {
|
||||||
'Authorization': 'Bearer ' + api_key,
|
"Authorization": "Bearer " + api_key,
|
||||||
'Content-Type': 'application/json'
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {
|
data = {"model": "Baichuan-Text-Embedding", "input": texts}
|
||||||
'model': 'Baichuan-Text-Embedding',
|
|
||||||
'input': texts
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = post(url, headers=headers, data=dumps(data))
|
response = post(url, headers=headers, data=dumps(data))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InvokeConnectionError(str(e))
|
raise InvokeConnectionError(str(e))
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
try:
|
try:
|
||||||
resp = response.json()
|
resp = response.json()
|
||||||
# try to parse error message
|
# try to parse error message
|
||||||
err = resp['error']['code']
|
err = resp["error"]["code"]
|
||||||
msg = resp['error']['message']
|
msg = resp["error"]["message"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
raise InternalServerError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
if err == 'invalid_api_key':
|
if err == "invalid_api_key":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
elif err == 'insufficient_quota':
|
elif err == "insufficient_quota":
|
||||||
raise InsufficientAccountBalance(msg)
|
raise InsufficientAccountBalance(msg)
|
||||||
elif err == 'invalid_authentication':
|
elif err == "invalid_authentication":
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
elif err and 'rate' in err:
|
elif err and "rate" in err:
|
||||||
raise RateLimitReachedError(msg)
|
raise RateLimitReachedError(msg)
|
||||||
elif err and 'internal' in err:
|
elif err and "internal" in err:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
elif err == 'api_key_empty':
|
elif err == "api_key_empty":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
else:
|
else:
|
||||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = response.json()
|
resp = response.json()
|
||||||
embeddings = resp['data']
|
embeddings = resp["data"]
|
||||||
usage = resp['usage']
|
usage = resp["usage"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
raise InternalServerError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
return [
|
)
|
||||||
data['embedding'] for data in embeddings
|
|
||||||
], usage['total_tokens']
|
|
||||||
|
|
||||||
|
return [data["embedding"] for data in embeddings], usage["total_tokens"]
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
"""
|
"""
|
||||||
@ -170,33 +178,27 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
except InvalidAPIKeyError:
|
except InvalidAPIKeyError:
|
||||||
raise CredentialsValidateFailedError('Invalid api key')
|
raise CredentialsValidateFailedError("Invalid api key")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [],
|
||||||
],
|
InvokeServerUnavailableError: [InternalServerError],
|
||||||
InvokeServerUnavailableError: [
|
InvokeRateLimitError: [RateLimitReachedError],
|
||||||
InternalServerError
|
|
||||||
],
|
|
||||||
InvokeRateLimitError: [
|
|
||||||
RateLimitReachedError
|
|
||||||
],
|
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalance,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||||
BadRequestError,
|
|
||||||
KeyError
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
def _calc_response_usage(
|
||||||
|
self, model: str, credentials: dict, tokens: int
|
||||||
|
) -> EmbeddingUsage:
|
||||||
"""
|
"""
|
||||||
Calculate response usage
|
Calculate response usage
|
||||||
|
|
||||||
@ -210,7 +212,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
price_type=PriceType.INPUT,
|
price_type=PriceType.INPUT,
|
||||||
tokens=tokens
|
tokens=tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
@ -221,7 +223,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
price_unit=input_price_info.unit,
|
price_unit=input_price_info.unit,
|
||||||
total_price=input_price_info.total_amount,
|
total_price=input_price_info.total_amount,
|
||||||
currency=input_price_info.currency,
|
currency=input_price_info.currency,
|
||||||
latency=time.perf_counter() - self.started_at
|
latency=time.perf_counter() - self.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BedrockProvider(ModelProvider):
|
class BedrockProvider(ModelProvider):
|
||||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@ -20,11 +25,12 @@ class BedrockProvider(ModelProvider):
|
|||||||
|
|
||||||
# Use `amazon.titan-text-lite-v1` model for validation,
|
# Use `amazon.titan-text-lite-v1` model for validation,
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='amazon.titan-text-lite-v1',
|
model="amazon.titan-text-lite-v1", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
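Provider-level validation above simply pings one inexpensive model and surfaces failures as `CredentialsValidateFailedError`. A hypothetical caller (region and keys are placeholders; the credential field names follow the Bedrock code later in this diff):

```python
# Hypothetical usage of BedrockProvider.validate_provider_credentials (placeholder values).
credentials = {
    "aws_region": "us-east-1",
    "aws_access_key_id": "AKIA...",
    "aws_secret_access_key": "...",
}
try:
    BedrockProvider().validate_provider_credentials(credentials)
except CredentialsValidateFailedError as exc:
    print(f"Bedrock credentials rejected: {exc}")
```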
@ -13,7 +13,11 @@ from botocore.exceptions import (
|
|||||||
UnknownServiceError,
|
UnknownServiceError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -29,18 +33,28 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
|
LargeLanguageModel,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
def _invoke(
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
self,
|
||||||
stream: bool = True, user: Optional[str] = None) \
|
model: str,
|
||||||
-> Union[LLMResult, Generator]:
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -55,10 +69,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
# invoke model
|
# invoke model
|
||||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
return self._generate(
|
||||||
|
model, credentials, prompt_messages, model_parameters, stop, stream, user
|
||||||
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
|
def get_num_tokens(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
messages: list[PromptMessage] | str,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for given prompt messages
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
@ -68,7 +89,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param tools: tools for tool calling
|
:param tools: tools for tool calling
|
||||||
:return: number of tokens
|
:return: number of tokens
|
||||||
"""
|
"""
|
||||||
prefix = model.split('.')[0]
|
prefix = model.split(".")[0]
|
||||||
|
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
prompt = messages
|
prompt = messages
|
||||||
@ -76,8 +97,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
prompt = self._convert_messages_to_prompt(messages, prefix)
|
prompt = self._convert_messages_to_prompt(messages, prefix)
|
||||||
|
|
||||||
return self._get_num_tokens_by_gpt2(prompt)
|
return self._get_num_tokens_by_gpt2(prompt)
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
|
def _convert_messages_to_prompt(
|
||||||
|
self, model_prefix: str, messages: list[PromptMessage]
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format a list of messages into a full prompt for the Bedrock model
|
Format a list of messages into a full prompt for the Bedrock model
|
||||||
|
|
||||||
@ -85,7 +108,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||||
"""
|
"""
|
||||||
messages = messages.copy() # don't mutate the original list
|
messages = messages.copy() # don't mutate the original list
|
||||||
|
|
||||||
text = "".join(
|
text = "".join(
|
||||||
self._convert_one_message_to_text(message, model_prefix)
|
self._convert_one_message_to_text(message, model_prefix)
|
||||||
for message in messages
|
for message in messages
|
||||||
@ -101,32 +124,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param credentials: model credentials
|
:param credentials: model credentials
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ping_message = UserPromptMessage(content="ping")
|
ping_message = UserPromptMessage(content="ping")
|
||||||
self._generate(model=model,
|
self._generate(
|
||||||
credentials=credentials,
|
model=model,
|
||||||
prompt_messages=[ping_message],
|
credentials=credentials,
|
||||||
model_parameters={},
|
prompt_messages=[ping_message],
|
||||||
stream=False)
|
model_parameters={},
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
error_code = ex.response['Error']['Code']
|
error_code = ex.response["Error"]["Code"]
|
||||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||||
|
|
||||||
raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg)))
|
raise CredentialsValidateFailedError(
|
||||||
|
str(self._map_client_to_invoke_error(error_code, full_error_msg))
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
|
def _convert_one_message_to_text(
|
||||||
|
self, message: PromptMessage, model_prefix: str
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a single message to a string.
|
Convert a single message to a string.
|
||||||
|
|
||||||
:param message: PromptMessage to convert.
|
:param message: PromptMessage to convert.
|
||||||
:return: String representation of the message.
|
:return: String representation of the message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if model_prefix == "anthropic":
|
if model_prefix == "anthropic":
|
||||||
human_prompt_prefix = "\n\nHuman:"
|
human_prompt_prefix = "\n\nHuman:"
|
||||||
human_prompt_postfix = ""
|
human_prompt_postfix = ""
|
||||||
@ -141,7 +170,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
human_prompt_prefix = "\n\nUser:"
|
human_prompt_prefix = "\n\nUser:"
|
||||||
human_prompt_postfix = ""
|
human_prompt_postfix = ""
|
||||||
ai_prompt = "\n\nBot:"
|
ai_prompt = "\n\nBot:"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
human_prompt_prefix = ""
|
human_prompt_prefix = ""
|
||||||
human_prompt_postfix = ""
|
human_prompt_postfix = ""
|
||||||
@ -160,7 +189,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return message_text
|
return message_text
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
|
def _convert_messages_to_prompt(
|
||||||
|
self, messages: list[PromptMessage], model_prefix: str
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
|
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
|
||||||
|
|
||||||
@ -168,7 +199,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
messages = messages.copy() # don't mutate the original list
|
messages = messages.copy() # don't mutate the original list
|
||||||
if not isinstance(messages[-1], AssistantPromptMessage):
|
if not isinstance(messages[-1], AssistantPromptMessage):
|
||||||
@ -182,23 +213,36 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||||
return text.rstrip()
|
return text.rstrip()
|
||||||
|
|
||||||
def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
def _create_payload(
|
||||||
|
self,
|
||||||
|
model_prefix: str,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create payload for bedrock api call depending on model provider
|
Create payload for bedrock api call depending on model provider
|
||||||
"""
|
"""
|
||||||
payload = dict()
|
payload = dict()
|
||||||
|
|
||||||
if model_prefix == "amazon":
|
if model_prefix == "amazon":
|
||||||
payload["textGenerationConfig"] = { **model_parameters }
|
payload["textGenerationConfig"] = {**model_parameters}
|
||||||
payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
|
payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (
|
||||||
|
stop if stop else []
|
||||||
payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
)
|
||||||
|
|
||||||
|
payload["inputText"] = self._convert_messages_to_prompt(
|
||||||
|
prompt_messages, model_prefix
|
||||||
|
)
|
||||||
|
|
||||||
elif model_prefix == "ai21":
|
elif model_prefix == "ai21":
|
||||||
payload["temperature"] = model_parameters.get("temperature")
|
payload["temperature"] = model_parameters.get("temperature")
|
||||||
payload["topP"] = model_parameters.get("topP")
|
payload["topP"] = model_parameters.get("topP")
|
||||||
payload["maxTokens"] = model_parameters.get("maxTokens")
|
payload["maxTokens"] = model_parameters.get("maxTokens")
|
||||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
payload["prompt"] = self._convert_messages_to_prompt(
|
||||||
|
prompt_messages, model_prefix
|
||||||
|
)
|
||||||
|
|
||||||
# jurassic models only support a single stop sequence
|
# jurassic models only support a single stop sequence
|
||||||
if stop:
|
if stop:
|
||||||
@ -212,28 +256,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
payload["countPenalty"] = {model_parameters.get("countPenalty")}
|
payload["countPenalty"] = {model_parameters.get("countPenalty")}
|
||||||
|
|
||||||
elif model_prefix == "anthropic":
|
elif model_prefix == "anthropic":
|
||||||
payload = { **model_parameters }
|
payload = {**model_parameters}
|
||||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
payload["prompt"] = self._convert_messages_to_prompt(
|
||||||
|
prompt_messages, model_prefix
|
||||||
|
)
|
||||||
payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
|
payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
|
||||||
|
|
||||||
elif model_prefix == "cohere":
|
elif model_prefix == "cohere":
|
||||||
payload = { **model_parameters }
|
payload = {**model_parameters}
|
||||||
payload["prompt"] = prompt_messages[0].content
|
payload["prompt"] = prompt_messages[0].content
|
||||||
payload["stream"] = stream
|
payload["stream"] = stream
|
||||||
|
|
||||||
elif model_prefix == "meta":
|
elif model_prefix == "meta":
|
||||||
payload = { **model_parameters }
|
payload = {**model_parameters}
|
||||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
payload["prompt"] = self._convert_messages_to_prompt(
|
||||||
|
prompt_messages, model_prefix
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown model prefix {model_prefix}")
|
raise ValueError(f"Got unknown model prefix {model_prefix}")
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
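`_create_payload` above dispatches on the model-id prefix. Condensed, the main request shapes look roughly like this; ai21's extra penalty fields and cohere's stream flag are omitted, and `prompt` stands in for the converted message text.

```python
# Condensed sketch of the per-provider Bedrock payload shapes assembled above.
def sketch_payload(model_prefix: str, prompt: str, params: dict, stop: list | None = None):
    stop = stop or []
    if model_prefix == "amazon":
        return {
            "inputText": prompt,
            "textGenerationConfig": {**params, "stopSequences": ["User:"] + stop},
        }
    if model_prefix == "anthropic":
        return {**params, "prompt": prompt, "stop_sequences": ["\n\nHuman:"] + stop}
    if model_prefix in ("cohere", "meta"):
        return {**params, "prompt": prompt}
    raise ValueError(f"Got unknown model prefix {model_prefix}")
```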
def _generate(self, model: str, credentials: dict,
|
def _generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
model: str,
|
||||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -246,19 +300,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
client_config = Config(
|
client_config = Config(region_name=credentials["aws_region"])
|
||||||
region_name=credentials["aws_region"]
|
|
||||||
)
|
|
||||||
|
|
||||||
runtime_client = boto3.client(
|
runtime_client = boto3.client(
|
||||||
service_name='bedrock-runtime',
|
service_name="bedrock-runtime",
|
||||||
config=client_config,
|
config=client_config,
|
||||||
aws_access_key_id=credentials["aws_access_key_id"],
|
aws_access_key_id=credentials["aws_access_key_id"],
|
||||||
aws_secret_access_key=credentials["aws_secret_access_key"]
|
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
model_prefix = model.split('.')[0]
|
model_prefix = model.split(".")[0]
|
||||||
payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
|
payload = self._create_payload(
|
||||||
|
model_prefix, prompt_messages, model_parameters, stop, stream
|
||||||
|
)
|
||||||
|
|
||||||
# need workaround for ai21 models which don't support streaming
|
# need workaround for ai21 models which don't support streaming
|
||||||
if stream and model_prefix != "ai21":
|
if stream and model_prefix != "ai21":
|
||||||
@ -267,18 +321,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
invoke = runtime_client.invoke_model
|
invoke = runtime_client.invoke_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body_jsonstr=json.dumps(payload)
|
body_jsonstr = json.dumps(payload)
|
||||||
response = invoke(
|
response = invoke(
|
||||||
modelId=model,
|
modelId=model,
|
||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
accept= "*/*",
|
accept="*/*",
|
||||||
body=body_jsonstr
|
body=body_jsonstr,
|
||||||
)
|
)
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
error_code = ex.response['Error']['Code']
|
error_code = ex.response["Error"]["Code"]
|
||||||
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
|
||||||
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
raise self._map_client_to_invoke_error(error_code, full_error_msg)
|
||||||
|
|
||||||
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
|
||||||
raise InvokeConnectionError(str(ex))
|
raise InvokeConnectionError(str(ex))
|
||||||
|
|
||||||
@ -287,15 +341,23 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise InvokeError(str(ex))
|
raise InvokeError(str(ex))
|
||||||
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str, credentials: dict, response: dict,
|
def _handle_generate_response(
|
||||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm response
|
Handle llm response
|
||||||
|
|
||||||
@ -305,7 +367,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param prompt_messages: prompt messages
|
:param prompt_messages: prompt messages
|
||||||
:return: llm response
|
:return: llm response
|
||||||
"""
|
"""
|
||||||
response_body = json.loads(response.get('body').read().decode('utf-8'))
|
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
||||||
|
|
||||||
finish_reason = response_body.get("error")
|
finish_reason = response_body.get("error")
|
||||||
|
|
||||||
@ -313,43 +375,51 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
raise InvokeError(finish_reason)
|
raise InvokeError(finish_reason)
|
||||||
|
|
||||||
# get output text and calculate num tokens based on model / provider
|
# get output text and calculate num tokens based on model / provider
|
||||||
model_prefix = model.split('.')[0]
|
model_prefix = model.split(".")[0]
|
||||||
|
|
||||||
if model_prefix == "amazon":
|
if model_prefix == "amazon":
|
||||||
output = response_body.get("results")[0].get("outputText").strip('\n')
|
output = response_body.get("results")[0].get("outputText").strip("\n")
|
||||||
prompt_tokens = response_body.get("inputTextTokenCount")
|
prompt_tokens = response_body.get("inputTextTokenCount")
|
||||||
completion_tokens = response_body.get("results")[0].get("tokenCount")
|
completion_tokens = response_body.get("results")[0].get("tokenCount")
|
||||||
|
|
||||||
elif model_prefix == "ai21":
|
elif model_prefix == "ai21":
|
||||||
output = response_body.get('completions')[0].get('data').get('text')
|
output = response_body.get("completions")[0].get("data").get("text")
|
||||||
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
||||||
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
|
completion_tokens = len(
|
||||||
|
response_body.get("completions")[0].get("data").get("tokens")
|
||||||
|
)
|
||||||
|
|
||||||
elif model_prefix == "anthropic":
|
elif model_prefix == "anthropic":
|
||||||
output = response_body.get("completion")
|
output = response_body.get("completion")
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
|
completion_tokens = self.get_num_tokens(
|
||||||
|
model, credentials, output if output else ""
|
||||||
|
)
|
||||||
|
|
||||||
elif model_prefix == "cohere":
|
elif model_prefix == "cohere":
|
||||||
output = response_body.get("generations")[0].get("text")
|
output = response_body.get("generations")[0].get("text")
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, output if output else '')
|
completion_tokens = self.get_num_tokens(
|
||||||
|
model, credentials, output if output else ""
|
||||||
|
)
|
||||||
|
|
||||||
elif model_prefix == "meta":
|
elif model_prefix == "meta":
|
||||||
output = response_body.get("generation").strip('\n')
|
output = response_body.get("generation").strip("\n")
|
||||||
prompt_tokens = response_body.get("prompt_token_count")
|
prompt_tokens = response_body.get("prompt_token_count")
|
||||||
completion_tokens = response_body.get("generation_token_count")
|
completion_tokens = response_body.get("generation_token_count")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
raise ValueError(
|
||||||
|
f"Got unknown model prefix {model_prefix} when handling block response"
|
||||||
|
)
|
||||||
|
|
||||||
# construct assistant message from output
|
# construct assistant message from output
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=output)
|
||||||
content=output
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate usage
|
# calculate usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# construct response
|
# construct response
|
||||||
result = LLMResult(
|
result = LLMResult(
|
||||||
@ -361,8 +431,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict,
|
def _handle_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage]) -> Generator:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> Generator:
|
||||||
"""
|
"""
|
||||||
Handle llm stream response
|
Handle llm stream response
|
||||||
|
|
||||||
@ -372,48 +447,52 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param prompt_messages: prompt messages
|
:param prompt_messages: prompt messages
|
||||||
:return: llm response chunk generator result
|
:return: llm response chunk generator result
|
||||||
"""
|
"""
|
||||||
model_prefix = model.split('.')[0]
|
model_prefix = model.split(".")[0]
|
||||||
if model_prefix == "ai21":
|
if model_prefix == "ai21":
|
||||||
response_body = json.loads(response.get('body').read().decode('utf-8'))
|
response_body = json.loads(response.get("body").read().decode("utf-8"))
|
||||||
|
|
||||||
content = response_body.get('completions')[0].get('data').get('text')
|
content = response_body.get("completions")[0].get("data").get("text")
|
||||||
finish_reason = response_body.get('completions')[0].get('finish_reason')
|
finish_reason = response_body.get("completions")[0].get("finish_reason")
|
||||||
|
|
||||||
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
prompt_tokens = len(response_body.get("prompt").get("tokens"))
|
||||||
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens'))
|
completion_tokens = len(
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
response_body.get("completions")[0].get("data").get("tokens")
|
||||||
|
)
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=0,
|
index=0,
|
||||||
message=AssistantPromptMessage(content=content),
|
message=AssistantPromptMessage(content=content),
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
stream = response.get('body')
|
stream = response.get("body")
|
||||||
if not stream:
|
if not stream:
|
||||||
raise InvokeError('No response body')
|
raise InvokeError("No response body")
|
||||||
|
|
||||||
index = -1
|
index = -1
|
||||||
for event in stream:
|
for event in stream:
|
||||||
chunk = event.get('chunk')
|
chunk = event.get("chunk")
|
||||||
|
|
||||||
if not chunk:
|
if not chunk:
|
||||||
exception_name = next(iter(event))
|
exception_name = next(iter(event))
|
||||||
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
|
full_ex_msg = f"{exception_name}: {event[exception_name]['message']}"
|
||||||
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
|
raise self._map_client_to_invoke_error(exception_name, full_ex_msg)
|
||||||
|
|
||||||
payload = json.loads(chunk.get('bytes').decode())
|
payload = json.loads(chunk.get("bytes").decode())
|
||||||
|
|
||||||
model_prefix = model.split('.')[0]
|
model_prefix = model.split(".")[0]
|
||||||
if model_prefix == "amazon":
|
if model_prefix == "amazon":
|
||||||
content_delta = payload.get("outputText").strip('\n')
|
content_delta = payload.get("outputText").strip("\n")
|
||||||
finish_reason = payload.get("completion_reason")
|
finish_reason = payload.get("completion_reason")
|
||||||
|
|
||||||
elif model_prefix == "anthropic":
|
elif model_prefix == "anthropic":
|
||||||
content_delta = payload.get("completion")
|
content_delta = payload.get("completion")
|
||||||
finish_reason = payload.get("stop_reason")
|
finish_reason = payload.get("stop_reason")
|
||||||
@ -421,38 +500,45 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
elif model_prefix == "cohere":
|
elif model_prefix == "cohere":
|
||||||
content_delta = payload.get("text")
|
content_delta = payload.get("text")
|
||||||
finish_reason = payload.get("finish_reason")
|
finish_reason = payload.get("finish_reason")
|
||||||
|
|
||||||
elif model_prefix == "meta":
|
elif model_prefix == "meta":
|
||||||
content_delta = payload.get("generation").strip('\n')
|
content_delta = payload.get("generation").strip("\n")
|
||||||
finish_reason = payload.get("stop_reason")
|
finish_reason = payload.get("stop_reason")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response")
|
raise ValueError(
|
||||||
|
f"Got unknown model prefix {model_prefix} when handling stream response"
|
||||||
|
)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content = content_delta if content_delta else '',
|
content=content_delta if content_delta else "",
|
||||||
)
|
)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
if not finish_reason:
|
if not finish_reason:
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index, message=assistant_prompt_message
|
||||||
message=assistant_prompt_message
|
),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# get num tokens from metrics in last chunk
|
# get num tokens from metrics in last chunk
|
||||||
prompt_tokens = payload["amazon-bedrock-invocationMetrics"]["inputTokenCount"]
|
prompt_tokens = payload["amazon-bedrock-invocationMetrics"][
|
||||||
completion_tokens = payload["amazon-bedrock-invocationMetrics"]["outputTokenCount"]
|
"inputTokenCount"
|
||||||
|
]
|
||||||
|
completion_tokens = payload["amazon-bedrock-invocationMetrics"][
|
||||||
|
"outputTokenCount"
|
||||||
|
]
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -460,10 +546,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=index,
|
index=index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
"""
|
"""
|
||||||
@ -479,10 +565,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
InvokeServerUnavailableError: [],
|
InvokeServerUnavailableError: [],
|
||||||
InvokeRateLimitError: [],
|
InvokeRateLimitError: [],
|
||||||
InvokeAuthorizationError: [],
|
InvokeAuthorizationError: [],
|
||||||
InvokeBadRequestError: []
|
InvokeBadRequestError: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
|
def _map_client_to_invoke_error(
|
||||||
|
self, error_code: str, error_msg: str
|
||||||
|
) -> type[InvokeError]:
|
||||||
"""
|
"""
|
||||||
Map client error to invoke error
|
Map client error to invoke error
|
||||||
|
|
||||||
@ -497,7 +585,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
return InvokeBadRequestError(error_msg)
|
return InvokeBadRequestError(error_msg)
|
||||||
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
|
||||||
return InvokeRateLimitError(error_msg)
|
return InvokeRateLimitError(error_msg)
|
||||||
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
|
elif error_code in [
|
||||||
|
"ModelTimeoutException",
|
||||||
|
"ModelErrorException",
|
||||||
|
"InternalServerException",
|
||||||
|
"ModelNotReadyException",
|
||||||
|
]:
|
||||||
return InvokeServerUnavailableError(error_msg)
|
return InvokeServerUnavailableError(error_msg)
|
||||||
elif error_code == "ModelStreamErrorException":
|
elif error_code == "ModelStreamErrorException":
|
||||||
return InvokeConnectionError(error_msg)
|
return InvokeConnectionError(error_msg)
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -21,11 +25,12 @@ class ChatGLMProvider(ModelProvider):
|
|||||||
|
|
||||||
# Use `chatglm3-6b` model for validate,
|
# Use `chatglm3-6b` model for validate,
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='chatglm3-6b',
|
model="chatglm3-6b", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
@ -20,7 +20,11 @@ from openai import (
|
|||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
from openai.types.chat.chat_completion_message import FunctionCall
|
from openai.types.chat.chat_completion_message import FunctionCall
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -37,18 +41,29 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
|
LargeLanguageModel,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.utils import helper
|
from model_providers.core.model_runtime.utils import helper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
model: str,
|
||||||
stream: bool = True, user: str | None = None) \
|
credentials: dict,
|
||||||
-> LLMResult | Generator:
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> LLMResult | Generator:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -71,11 +86,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
user=user
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(
|
||||||
tools: list[PromptMessageTool] | None = None) -> int:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for given prompt messages
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
@ -96,11 +116,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._invoke(model=model, credentials=credentials, prompt_messages=[
|
self._invoke(
|
||||||
UserPromptMessage(content="ping"),
|
model=model,
|
||||||
], model_parameters={
|
credentials=credentials,
|
||||||
"max_tokens": 16,
|
prompt_messages=[
|
||||||
})
|
UserPromptMessage(content="ping"),
|
||||||
|
],
|
||||||
|
model_parameters={
|
||||||
|
"max_tokens": 16,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise CredentialsValidateFailedError(str(e))
|
raise CredentialsValidateFailedError(str(e))
|
||||||
|
|
||||||
@ -124,24 +149,24 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
ConflictError,
|
ConflictError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnprocessableEntityError,
|
UnprocessableEntityError,
|
||||||
PermissionDeniedError
|
PermissionDeniedError,
|
||||||
],
|
],
|
||||||
InvokeRateLimitError: [
|
InvokeRateLimitError: [RateLimitError],
|
||||||
RateLimitError
|
InvokeAuthorizationError: [AuthenticationError],
|
||||||
],
|
InvokeBadRequestError: [ValueError],
|
||||||
InvokeAuthorizationError: [
|
|
||||||
AuthenticationError
|
|
||||||
],
|
|
||||||
InvokeBadRequestError: [
|
|
||||||
ValueError
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict,
|
def _generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
model: str,
|
||||||
stream: bool = True, user: str | None = None) \
|
credentials: dict,
|
||||||
-> LLMResult | Generator:
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> LLMResult | Generator:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -155,7 +180,9 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._check_chatglm_parameters(model=model, model_parameters=model_parameters, tools=tools)
|
self._check_chatglm_parameters(
|
||||||
|
model=model, model_parameters=model_parameters, tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = self._to_client_kwargs(credentials)
|
kwargs = self._to_client_kwargs(credentials)
|
||||||
# init model client
|
# init model client
|
||||||
@ -163,13 +190,13 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
extra_model_kwargs = {}
|
extra_model_kwargs = {}
|
||||||
if stop:
|
if stop:
|
||||||
extra_model_kwargs['stop'] = stop
|
extra_model_kwargs["stop"] = stop
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
extra_model_kwargs['user'] = user
|
extra_model_kwargs["user"] = user
|
||||||
|
|
||||||
if tools and len(tools) > 0:
|
if tools and len(tools) > 0:
|
||||||
extra_model_kwargs['functions'] = [
|
extra_model_kwargs["functions"] = [
|
||||||
helper.dump_model(tool) for tool in tools
|
helper.dump_model(tool) for tool in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -178,21 +205,29 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
model=model,
|
model=model,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs
|
**extra_model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_chat_generate_stream_response(
|
return self._handle_chat_generate_stream_response(
|
||||||
model=model, credentials=credentials, response=result, tools=tools,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
tools=tools,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_chat_generate_response(
|
return self._handle_chat_generate_response(
|
||||||
model=model, credentials=credentials, response=result, tools=tools,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
tools=tools,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None:
|
def _check_chatglm_parameters(
|
||||||
|
self, model: str, model_parameters: dict, tools: list[PromptMessageTool]
|
||||||
|
) -> None:
|
||||||
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
|
if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0:
|
||||||
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
|
raise InvokeBadRequestError("ChatGLM2 does not support function calling")
|
||||||
|
|
||||||
@ -212,7 +247,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
if message.tool_calls and len(message.tool_calls) > 0:
|
if message.tool_calls and len(message.tool_calls) > 0:
|
||||||
message_dict["function_call"] = {
|
message_dict["function_call"] = {
|
||||||
"name": message.tool_calls[0].function.name,
|
"name": message.tool_calls[0].function.name,
|
||||||
"arguments": message.tool_calls[0].function.arguments
|
"arguments": message.tool_calls[0].function.arguments,
|
||||||
}
|
}
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage):
|
||||||
message = cast(SystemPromptMessage, message)
|
message = cast(SystemPromptMessage, message)
|
||||||
@ -223,12 +258,12 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
message_dict = {"role": "function", "content": message.content}
|
message_dict = {"role": "function", "content": message.content}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown message type {type(message)}")
|
raise ValueError(f"Unknown message type {type(message)}")
|
||||||
|
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
def _extract_response_tool_calls(self,
|
def _extract_response_tool_calls(
|
||||||
response_function_calls: list[FunctionCall]) \
|
self, response_function_calls: list[FunctionCall]
|
||||||
-> list[AssistantPromptMessage.ToolCall]:
|
) -> list[AssistantPromptMessage.ToolCall]:
|
||||||
"""
|
"""
|
||||||
Extract tool calls from response
|
Extract tool calls from response
|
||||||
|
|
||||||
@ -239,19 +274,16 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
if response_function_calls:
|
if response_function_calls:
|
||||||
for response_tool_call in response_function_calls:
|
for response_tool_call in response_function_calls:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_tool_call.name,
|
name=response_tool_call.name, arguments=response_tool_call.arguments
|
||||||
arguments=response_tool_call.arguments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=0,
|
id=0, type="function", function=function
|
||||||
type='function',
|
|
||||||
function=function
|
|
||||||
)
|
)
|
||||||
tool_calls.append(tool_call)
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
def _to_client_kwargs(self, credentials: dict) -> dict:
|
def _to_client_kwargs(self, credentials: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Convert invoke kwargs to client kwargs
|
Convert invoke kwargs to client kwargs
|
||||||
@ -265,17 +297,20 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||||
"api_key": "1",
|
"api_key": "1",
|
||||||
"base_url": join(credentials['api_base'], 'v1')
|
"base_url": join(credentials["api_base"], "v1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
return client_kwargs
|
return client_kwargs
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk],
|
def _handle_chat_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None) \
|
model: str,
|
||||||
-> Generator:
|
credentials: dict,
|
||||||
|
response: Stream[ChatCompletionChunk],
|
||||||
full_response = ''
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> Generator:
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
@ -283,35 +318,46 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
if delta.finish_reason is None and (
|
||||||
|
delta.delta.content is None or delta.delta.content == ""
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check if there is a tool call in the response
|
# check if there is a tool call in the response
|
||||||
function_calls = None
|
function_calls = None
|
||||||
if delta.delta.function_call:
|
if delta.delta.function_call:
|
||||||
function_calls = [delta.delta.function_call]
|
function_calls = [delta.delta.function_call]
|
||||||
|
|
||||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
|
assistant_message_tool_calls = self._extract_response_tool_calls(
|
||||||
|
function_calls if function_calls else []
|
||||||
|
)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else '',
|
content=delta.delta.content if delta.delta.content else "",
|
||||||
tool_calls=assistant_message_tool_calls
|
tool_calls=assistant_message_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
# temp_assistant_prompt_message is used to calculate usage
|
# temp_assistant_prompt_message is used to calculate usage
|
||||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=full_response,
|
content=full_response, tool_calls=assistant_message_tool_calls
|
||||||
tool_calls=assistant_message_tool_calls
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
messages=prompt_messages, tools=tools
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[temp_assistant_prompt_message], tools=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
|
||||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -320,7 +366,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=delta.finish_reason,
|
finish_reason=delta.finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -335,11 +381,15 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
full_response += delta.delta.content
|
full_response += delta.delta.content
|
||||||
|
|
||||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion,
|
def _handle_chat_generate_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None) \
|
model: str,
|
||||||
-> LLMResult:
|
credentials: dict,
|
||||||
|
response: ChatCompletion,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm chat response
|
Handle llm chat response
|
||||||
|
|
||||||
@ -356,18 +406,28 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
# convert function call to tool call
|
# convert function call to tool call
|
||||||
function_calls = assistant_message.function_call
|
function_calls = assistant_message.function_call
|
||||||
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
|
tool_calls = self._extract_response_tool_calls(
|
||||||
|
[function_calls] if function_calls else []
|
||||||
|
)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=assistant_message.content,
|
content=assistant_message.content, tool_calls=tool_calls
|
||||||
tool_calls=tool_calls
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
|
messages=prompt_messages, tools=tools
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[assistant_prompt_message], tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
@ -378,8 +438,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
|
def _num_tokens_from_string(
|
||||||
|
self, text: str, tools: Optional[list[PromptMessageTool]] = None
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate num tokens for text completion model with tiktoken package.
|
Calculate num tokens for text completion model with tiktoken package.
|
||||||
|
|
||||||
@ -395,17 +457,21 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, messages: list[PromptMessage],
|
def _num_tokens_from_messages(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
|
"""Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
|
||||||
|
|
||||||
it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
|
it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer,
|
||||||
As a temporary solution we use GPT2 tokenizer instead.
|
As a temporary solution we use GPT2 tokenizer instead.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def tokens(text: str):
|
def tokens(text: str):
|
||||||
return self._get_num_tokens_by_gpt2(text)
|
return self._get_num_tokens_by_gpt2(text)
|
||||||
|
|
||||||
tokens_per_message = 3
|
tokens_per_message = 3
|
||||||
tokens_per_name = 1
|
tokens_per_name = 1
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
@ -414,10 +480,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
text = ''
|
text = ""
|
||||||
for item in value:
|
for item in value:
|
||||||
if isinstance(item, dict) and item['type'] == 'text':
|
if isinstance(item, dict) and item["type"] == "text":
|
||||||
text += item['text']
|
text += item["text"]
|
||||||
value = text
|
value = text
|
||||||
|
|
||||||
if key == "function_call":
|
if key == "function_call":
|
||||||
@ -452,36 +518,37 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
|||||||
:param tools: tools for tool calling
|
:param tools: tools for tool calling
|
||||||
:return: number of tokens
|
:return: number of tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def tokens(text: str):
|
def tokens(text: str):
|
||||||
return self._get_num_tokens_by_gpt2(text)
|
return self._get_num_tokens_by_gpt2(text)
|
||||||
|
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
# calculate num tokens for function object
|
# calculate num tokens for function object
|
||||||
num_tokens += tokens('name')
|
num_tokens += tokens("name")
|
||||||
num_tokens += tokens(tool.name)
|
num_tokens += tokens(tool.name)
|
||||||
num_tokens += tokens('description')
|
num_tokens += tokens("description")
|
||||||
num_tokens += tokens(tool.description)
|
num_tokens += tokens(tool.description)
|
||||||
parameters = tool.parameters
|
parameters = tool.parameters
|
||||||
num_tokens += tokens('parameters')
|
num_tokens += tokens("parameters")
|
||||||
num_tokens += tokens('type')
|
num_tokens += tokens("type")
|
||||||
num_tokens += tokens(parameters.get("type"))
|
num_tokens += tokens(parameters.get("type"))
|
||||||
if 'properties' in parameters:
|
if "properties" in parameters:
|
||||||
num_tokens += tokens('properties')
|
num_tokens += tokens("properties")
|
||||||
for key, value in parameters.get('properties').items():
|
for key, value in parameters.get("properties").items():
|
||||||
num_tokens += tokens(key)
|
num_tokens += tokens(key)
|
||||||
for field_key, field_value in value.items():
|
for field_key, field_value in value.items():
|
||||||
num_tokens += tokens(field_key)
|
num_tokens += tokens(field_key)
|
||||||
if field_key == 'enum':
|
if field_key == "enum":
|
||||||
for enum_field in field_value:
|
for enum_field in field_value:
|
||||||
num_tokens += 3
|
num_tokens += 3
|
||||||
num_tokens += tokens(enum_field)
|
num_tokens += tokens(enum_field)
|
||||||
else:
|
else:
|
||||||
num_tokens += tokens(field_key)
|
num_tokens += tokens(field_key)
|
||||||
num_tokens += tokens(str(field_value))
|
num_tokens += tokens(str(field_value))
|
||||||
if 'required' in parameters:
|
if "required" in parameters:
|
||||||
num_tokens += tokens('required')
|
num_tokens += tokens("required")
|
||||||
for required_field in parameters['required']:
|
for required_field in parameters["required"]:
|
||||||
num_tokens += 3
|
num_tokens += 3
|
||||||
num_tokens += tokens(required_field)
|
num_tokens += tokens(required_field)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -21,11 +25,12 @@ class CohereProvider(ModelProvider):
|
|||||||
|
|
||||||
# Use `rerank-english-v2.0` model for validate,
|
# Use `rerank-english-v2.0` model for validate,
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='rerank-english-v2.0',
|
model="rerank-english-v2.0", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
@ -7,7 +7,12 @@ from cohere.responses import Chat, Generations
|
|||||||
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
|
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
|
||||||
from cohere.responses.generation import StreamingGenerations, StreamingText
|
from cohere.responses.generation import StreamingGenerations, StreamingText
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMMode,
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -17,7 +22,12 @@ from model_providers.core.model_runtime.entities.message_entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
I18nObject,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.errors.invoke import (
|
from model_providers.core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
InvokeBadRequestError,
|
InvokeBadRequestError,
|
||||||
@ -26,8 +36,12 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
|
LargeLanguageModel,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -37,11 +51,17 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
Model class for Cohere large language model.
|
Model class for Cohere large language model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
model: str,
|
||||||
stream: bool = True, user: Optional[str] = None) \
|
credentials: dict,
|
||||||
-> Union[LLMResult, Generator]:
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -66,7 +86,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
user=user
|
user=user,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -76,11 +96,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
user=user
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for given prompt messages
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
@ -95,9 +120,13 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if model_mode == LLMMode.CHAT:
|
if model_mode == LLMMode.CHAT:
|
||||||
return self._num_tokens_from_messages(model, credentials, prompt_messages)
|
return self._num_tokens_from_messages(
|
||||||
|
model, credentials, prompt_messages
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
|
return self._num_tokens_from_string(
|
||||||
|
model, credentials, prompt_messages[0].content
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise self._transform_invoke_error(e)
|
raise self._transform_invoke_error(e)
|
||||||
|
|
||||||
@ -117,30 +146,37 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
self._chat_generate(
|
self._chat_generate(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
prompt_messages=[UserPromptMessage(content='ping')],
|
prompt_messages=[UserPromptMessage(content="ping")],
|
||||||
model_parameters={
|
model_parameters={
|
||||||
'max_tokens': 20,
|
"max_tokens": 20,
|
||||||
'temperature': 0,
|
"temperature": 0,
|
||||||
},
|
},
|
||||||
stream=False
|
stream=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._generate(
|
self._generate(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
prompt_messages=[UserPromptMessage(content='ping')],
|
prompt_messages=[UserPromptMessage(content="ping")],
|
||||||
model_parameters={
|
model_parameters={
|
||||||
'max_tokens': 20,
|
"max_tokens": 20,
|
||||||
'temperature': 0,
|
"temperature": 0,
|
||||||
},
|
},
|
||||||
stream=False
|
stream=False,
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict,
|
def _generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
self,
|
||||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke llm model
|
Invoke llm model
|
||||||
|
|
||||||
@ -154,10 +190,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
model_parameters['end_sequences'] = stop
|
model_parameters["end_sequences"] = stop
|
||||||
|
|
||||||
response = client.generate(
|
response = client.generate(
|
||||||
prompt=prompt_messages[0].content,
|
prompt=prompt_messages[0].content,
|
||||||
@ -167,13 +203,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
|
def _handle_generate_response(
|
||||||
prompt_messages: list[PromptMessage]) \
|
self,
|
||||||
-> LLMResult:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: Generations,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm response
|
Handle llm response
|
||||||
|
|
||||||
@ -186,29 +230,34 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
assistant_text = response.generations[0].text
|
assistant_text = response.generations[0].text
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||||
content=assistant_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = response.meta['billed_units']['input_tokens']
|
prompt_tokens = response.meta["billed_units"]["input_tokens"]
|
||||||
completion_tokens = response.meta['billed_units']['output_tokens']
|
completion_tokens = response.meta["billed_units"]["output_tokens"]
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# transform response
|
# transform response
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
|
def _handle_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage]) -> Generator:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: StreamingGenerations,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> Generator:
|
||||||
"""
|
"""
|
||||||
Handle llm stream response
|
Handle llm stream response
|
||||||
|
|
||||||
@ -218,7 +267,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: llm response chunk generator
|
:return: llm response chunk generator
|
||||||
"""
|
"""
|
||||||
index = 1
|
index = 1
|
||||||
full_assistant_content = ''
|
full_assistant_content = ""
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if isinstance(chunk, StreamingText):
|
if isinstance(chunk, StreamingText):
|
||||||
chunk = cast(StreamingText, chunk)
|
chunk = cast(StreamingText, chunk)
|
||||||
@ -228,9 +277,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||||
content=text
|
|
||||||
)
|
|
||||||
|
|
||||||
full_assistant_content += text
|
full_assistant_content += text
|
||||||
|
|
||||||
@ -240,33 +287,42 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
index += 1
|
index += 1
|
||||||
elif chunk is None:
|
elif chunk is None:
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = response.meta['billed_units']['input_tokens']
|
prompt_tokens = response.meta["billed_units"]["input_tokens"]
|
||||||
completion_tokens = response.meta['billed_units']['output_tokens']
|
completion_tokens = response.meta["billed_units"]["output_tokens"]
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=AssistantPromptMessage(content=''),
|
message=AssistantPromptMessage(content=""),
|
||||||
finish_reason=response.finish_reason,
|
finish_reason=response.finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
def _chat_generate(self, model: str, credentials: dict,
|
def _chat_generate(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
self,
|
||||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke llm chat model
|
Invoke llm chat model
|
||||||
|
|
||||||
@ -280,17 +336,23 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
model_parameters['user_name'] = user
|
model_parameters["user_name"] = user
|
||||||
|
|
||||||
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
(
|
||||||
|
message,
|
||||||
|
chat_histories,
|
||||||
|
) = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
||||||
|
|
||||||
# chat model
|
# chat model
|
||||||
real_model = model
|
real_model = model
|
||||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
if (
|
||||||
real_model = model.removesuffix('-chat')
|
self.get_model_schema(model, credentials).fetch_from
|
||||||
|
== FetchFrom.PREDEFINED_MODEL
|
||||||
|
):
|
||||||
|
real_model = model.removesuffix("-chat")
|
||||||
|
|
||||||
response = client.chat(
|
response = client.chat(
|
||||||
message=message,
|
message=message,
|
||||||
@ -302,13 +364,22 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
|
return self._handle_chat_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages, stop
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
|
return self._handle_chat_generate_response(
|
||||||
|
model, credentials, response, prompt_messages, stop
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
|
def _handle_chat_generate_response(
|
||||||
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
|
self,
|
||||||
-> LLMResult:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: Chat,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm chat response
|
Handle llm chat response
|
||||||
|
|
||||||
@ -322,23 +393,25 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
assistant_text = response.text
|
assistant_text = response.text
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||||
content=assistant_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])
|
model, credentials, prompt_messages
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
model, credentials, [assistant_prompt_message]
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
# enforce stop tokens
|
# enforce stop tokens
|
||||||
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
|
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
||||||
content=assistant_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform response
|
# transform response
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
@ -346,14 +419,19 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
system_fingerprint=response.preamble
|
system_fingerprint=response.preamble,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
|
def _handle_chat_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
stop: Optional[list[str]] = None) -> Generator:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: StreamingChat,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
) -> Generator:
|
||||||
"""
|
"""
|
||||||
Handle llm chat stream response
|
Handle llm chat stream response
|
||||||
|
|
||||||
@ -364,18 +442,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: llm response chunk generator
|
:return: llm response chunk generator
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
|
def final_response(
|
||||||
preamble: Optional[str] = None) -> LLMResultChunk:
|
full_text: str,
|
||||||
|
index: int,
|
||||||
|
finish_reason: Optional[str] = None,
|
||||||
|
preamble: Optional[str] = None,
|
||||||
|
) -> LLMResultChunk:
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
|
model, credentials, prompt_messages
|
||||||
full_assistant_prompt_message = AssistantPromptMessage(
|
)
|
||||||
content=full_text
|
|
||||||
|
full_assistant_prompt_message = AssistantPromptMessage(content=full_text)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
model, credentials, [full_assistant_prompt_message]
|
||||||
)
|
)
|
||||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
return LLMResultChunk(
|
return LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
@ -383,14 +469,14 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
system_fingerprint=preamble,
|
system_fingerprint=preamble,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=AssistantPromptMessage(content=''),
|
message=AssistantPromptMessage(content=""),
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
index = 1
|
index = 1
|
||||||
full_assistant_content = ''
|
full_assistant_content = ""
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if isinstance(chunk, StreamTextGeneration):
|
if isinstance(chunk, StreamTextGeneration):
|
||||||
chunk = cast(StreamTextGeneration, chunk)
|
chunk = cast(StreamTextGeneration, chunk)
|
||||||
@ -400,14 +486,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||||
content=text
|
|
||||||
)
|
|
||||||
|
|
||||||
# stop
|
# stop
|
||||||
# notice: This logic can only cover few stop scenarios
|
# notice: This logic can only cover few stop scenarios
|
||||||
if stop and text in stop:
|
if stop and text in stop:
|
||||||
yield final_response(full_assistant_content, index, 'stop')
|
yield final_response(full_assistant_content, index, "stop")
|
||||||
break
|
break
|
||||||
|
|
||||||
full_assistant_content += text
|
full_assistant_content += text
|
||||||
@ -418,17 +502,23 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
index += 1
|
index += 1
|
||||||
elif isinstance(chunk, StreamEnd):
|
elif isinstance(chunk, StreamEnd):
|
||||||
chunk = cast(StreamEnd, chunk)
|
chunk = cast(StreamEnd, chunk)
|
||||||
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
|
yield final_response(
|
||||||
|
full_assistant_content,
|
||||||
|
index,
|
||||||
|
chunk.finish_reason,
|
||||||
|
response.preamble,
|
||||||
|
)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
|
def _convert_prompt_messages_to_message_and_chat_histories(
|
||||||
-> tuple[str, list[dict]]:
|
self, prompt_messages: list[PromptMessage]
|
||||||
|
) -> tuple[str, list[dict]]:
|
||||||
"""
|
"""
|
||||||
Convert prompt messages to message and chat histories
|
Convert prompt messages to message and chat histories
|
||||||
:param prompt_messages: prompt messages
|
:param prompt_messages: prompt messages
|
||||||
@ -441,9 +531,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
# get latest message from chat histories and pop it
|
# get latest message from chat histories and pop it
|
||||||
if len(chat_histories) > 0:
|
if len(chat_histories) > 0:
|
||||||
latest_message = chat_histories.pop()
|
latest_message = chat_histories.pop()
|
||||||
message = latest_message['message']
|
message = latest_message["message"]
|
||||||
else:
|
else:
|
||||||
raise ValueError('Prompt messages is empty')
|
raise ValueError("Prompt messages is empty")
|
||||||
|
|
||||||
return message, chat_histories
|
return message, chat_histories
|
||||||
|
|
||||||
@ -456,10 +546,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
if isinstance(message.content, str):
|
if isinstance(message.content, str):
|
||||||
message_dict = {"role": "USER", "message": message.content}
|
message_dict = {"role": "USER", "message": message.content}
|
||||||
else:
|
else:
|
||||||
sub_message_text = ''
|
sub_message_text = ""
|
||||||
for message_content in message.content:
|
for message_content in message.content:
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
message_content = cast(TextPromptMessageContent, message_content)
|
message_content = cast(
|
||||||
|
TextPromptMessageContent, message_content
|
||||||
|
)
|
||||||
sub_message_text += message_content.data
|
sub_message_text += message_content.data
|
||||||
|
|
||||||
message_dict = {"role": "USER", "message": sub_message_text}
|
message_dict = {"role": "USER", "message": sub_message_text}
|
||||||
@ -487,47 +579,53 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: number of tokens
|
:return: number of tokens
|
||||||
"""
|
"""
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
|
|
||||||
response = client.tokenize(
|
response = client.tokenize(text=text, model=model)
|
||||||
text=text,
|
|
||||||
model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.length
|
return response.length
|
||||||
|
|
||||||
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
|
def _num_tokens_from_messages(
|
||||||
|
self, model: str, credentials: dict, messages: list[PromptMessage]
|
||||||
|
) -> int:
|
||||||
"""Calculate num tokens Cohere model."""
|
"""Calculate num tokens Cohere model."""
|
||||||
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||||
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
|
message_strs = [
|
||||||
|
f"{message['role']}: {message['message']}" for message in messages
|
||||||
|
]
|
||||||
message_str = "\n".join(message_strs)
|
message_str = "\n".join(message_strs)
|
||||||
|
|
||||||
real_model = model
|
real_model = model
|
||||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
if (
|
||||||
real_model = model.removesuffix('-chat')
|
self.get_model_schema(model, credentials).fetch_from
|
||||||
|
== FetchFrom.PREDEFINED_MODEL
|
||||||
|
):
|
||||||
|
real_model = model.removesuffix("-chat")
|
||||||
|
|
||||||
return self._num_tokens_from_string(real_model, credentials, message_str)
|
return self._num_tokens_from_string(real_model, credentials, message_str)
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
def get_customizable_model_schema(
|
||||||
|
self, model: str, credentials: dict
|
||||||
|
) -> AIModelEntity:
|
||||||
"""
|
"""
|
||||||
Cohere supports fine-tuning of their models. This method returns the schema of the base model
|
Cohere supports fine-tuning of their models. This method returns the schema of the base model
|
||||||
but renamed to the fine-tuned model name.
|
but renamed to the fine-tuned model name.
|
||||||
|
|
||||||
:param model: model name
|
:param model: model name
|
||||||
:param credentials: credentials
|
:param credentials: credentials
|
||||||
|
|
||||||
:return: model schema
|
:return: model schema
|
||||||
"""
|
"""
|
||||||
# get model schema
|
# get model schema
|
||||||
models = self.predefined_models()
|
models = self.predefined_models()
|
||||||
model_map = {model.model: model for model in models}
|
model_map = {model.model: model for model in models}
|
||||||
|
|
||||||
mode = credentials.get('mode')
|
mode = credentials.get("mode")
|
||||||
|
|
||||||
if mode == 'chat':
|
if mode == "chat":
|
||||||
base_model_schema = model_map['command-light-chat']
|
base_model_schema = model_map["command-light-chat"]
|
||||||
else:
|
else:
|
||||||
base_model_schema = model_map['command-light']
|
base_model_schema = model_map["command-light"]
|
||||||
|
|
||||||
base_model_schema = cast(AIModelEntity, base_model_schema)
|
base_model_schema = cast(AIModelEntity, base_model_schema)
|
||||||
|
|
||||||
@ -537,18 +635,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans=model, en_US=model),
|
||||||
zh_Hans=model,
|
|
||||||
en_US=model
|
|
||||||
),
|
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
features=[feature for feature in base_model_schema_features],
|
features=[feature for feature in base_model_schema_features],
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_properties={
|
model_properties={
|
||||||
key: property for key, property in base_model_schema_model_properties.items()
|
key: property
|
||||||
|
for key, property in base_model_schema_model_properties.items()
|
||||||
},
|
},
|
||||||
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
||||||
pricing=base_model_schema.pricing
|
pricing=base_model_schema.pricing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
@ -564,14 +660,12 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [cohere.CohereConnectionError],
|
||||||
cohere.CohereConnectionError
|
|
||||||
],
|
|
||||||
InvokeServerUnavailableError: [],
|
InvokeServerUnavailableError: [],
|
||||||
InvokeRateLimitError: [],
|
InvokeRateLimitError: [],
|
||||||
InvokeAuthorizationError: [],
|
InvokeAuthorizationError: [],
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [
|
||||||
cohere.CohereAPIError,
|
cohere.CohereAPIError,
|
||||||
cohere.CohereError,
|
cohere.CohereError,
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,10 @@ from typing import Optional
|
|||||||
|
|
||||||
import cohere
|
import cohere
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
from model_providers.core.model_runtime.entities.rerank_entities import (
|
||||||
|
RerankDocument,
|
||||||
|
RerankResult,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.errors.invoke import (
|
from model_providers.core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
InvokeBadRequestError,
|
InvokeBadRequestError,
|
||||||
@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
|
||||||
|
RerankModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CohereRerankModel(RerankModel):
|
class CohereRerankModel(RerankModel):
|
||||||
@ -20,10 +27,16 @@ class CohereRerankModel(RerankModel):
|
|||||||
Model class for Cohere rerank model.
|
Model class for Cohere rerank model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
self,
|
||||||
user: Optional[str] = None) \
|
model: str,
|
||||||
-> RerankResult:
|
credentials: dict,
|
||||||
|
query: str,
|
||||||
|
docs: list[str],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> RerankResult:
|
||||||
"""
|
"""
|
||||||
Invoke rerank model
|
Invoke rerank model
|
||||||
|
|
||||||
@ -37,26 +50,18 @@ class CohereRerankModel(RerankModel):
|
|||||||
:return: rerank result
|
:return: rerank result
|
||||||
"""
|
"""
|
||||||
if len(docs) == 0:
|
if len(docs) == 0:
|
||||||
return RerankResult(
|
return RerankResult(model=model, docs=docs)
|
||||||
model=model,
|
|
||||||
docs=docs
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
results = client.rerank(
|
results = client.rerank(query=query, documents=docs, model=model, top_n=top_n)
|
||||||
query=query,
|
|
||||||
documents=docs,
|
|
||||||
model=model,
|
|
||||||
top_n=top_n
|
|
||||||
)
|
|
||||||
|
|
||||||
rerank_documents = []
|
rerank_documents = []
|
||||||
for idx, result in enumerate(results):
|
for idx, result in enumerate(results):
|
||||||
# format document
|
# format document
|
||||||
rerank_document = RerankDocument(
|
rerank_document = RerankDocument(
|
||||||
index=result.index,
|
index=result.index,
|
||||||
text=result.document['text'],
|
text=result.document["text"],
|
||||||
score=result.relevance_score,
|
score=result.relevance_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,10 +72,7 @@ class CohereRerankModel(RerankModel):
|
|||||||
else:
|
else:
|
||||||
rerank_documents.append(rerank_document)
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
return RerankResult(
|
return RerankResult(model=model, docs=rerank_documents)
|
||||||
model=model,
|
|
||||||
docs=rerank_documents
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
@ -91,7 +93,7 @@ class CohereRerankModel(RerankModel):
|
|||||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
],
|
],
|
||||||
score_threshold=0.8
|
score_threshold=0.8,
|
||||||
)
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
@ -116,5 +118,5 @@ class CohereRerankModel(RerankModel):
|
|||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [
|
||||||
cohere.CohereAPIError,
|
cohere.CohereAPIError,
|
||||||
cohere.CohereError,
|
cohere.CohereError,
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,10 @@ import numpy as np
|
|||||||
from cohere.responses import Tokens
|
from cohere.responses import Tokens
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import PriceType
|
from model_providers.core.model_runtime.entities.model_entities import PriceType
|
||||||
from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from model_providers.core.model_runtime.entities.text_embedding_entities import (
|
||||||
|
EmbeddingUsage,
|
||||||
|
TextEmbeddingResult,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.errors.invoke import (
|
from model_providers.core.model_runtime.errors.invoke import (
|
||||||
InvokeAuthorizationError,
|
InvokeAuthorizationError,
|
||||||
InvokeBadRequestError,
|
InvokeBadRequestError,
|
||||||
@ -15,8 +18,12 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
|
||||||
|
TextEmbeddingModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CohereTextEmbeddingModel(TextEmbeddingModel):
|
class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||||
@ -24,9 +31,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
Model class for Cohere text embedding model.
|
Model class for Cohere text embedding model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
texts: list[str], user: Optional[str] = None) \
|
self,
|
||||||
-> TextEmbeddingResult:
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
texts: list[str],
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
"""
|
"""
|
||||||
Invoke text embedding model
|
Invoke text embedding model
|
||||||
|
|
||||||
@ -47,13 +58,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
tokenize_response = self._tokenize(
|
tokenize_response = self._tokenize(
|
||||||
model=model,
|
model=model, credentials=credentials, text=text
|
||||||
credentials=credentials,
|
|
||||||
text=text
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for j in range(0, tokenize_response.length, context_size):
|
for j in range(0, tokenize_response.length, context_size):
|
||||||
tokens += [tokenize_response.token_strings[j: j + context_size]]
|
tokens += [tokenize_response.token_strings[j : j + context_size]]
|
||||||
indices += [i]
|
indices += [i]
|
||||||
|
|
||||||
batched_embeddings = []
|
batched_embeddings = []
|
||||||
@ -64,7 +73,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
texts=["".join(token) for token in tokens[i: i + max_chunks]]
|
texts=["".join(token) for token in tokens[i : i + max_chunks]],
|
||||||
)
|
)
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
used_tokens += embedding_used_tokens
|
||||||
@ -80,9 +89,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
_result = results[i]
|
_result = results[i]
|
||||||
if len(_result) == 0:
|
if len(_result) == 0:
|
||||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||||
model=model,
|
model=model, credentials=credentials, texts=[" "]
|
||||||
credentials=credentials,
|
|
||||||
texts=[" "]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
used_tokens += embedding_used_tokens
|
used_tokens += embedding_used_tokens
|
||||||
@ -93,16 +100,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
|
|
||||||
# calc usage
|
# calc usage
|
||||||
usage = self._calc_response_usage(
|
usage = self._calc_response_usage(
|
||||||
model=model,
|
model=model, credentials=credentials, tokens=used_tokens
|
||||||
credentials=credentials,
|
|
||||||
tokens=used_tokens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return TextEmbeddingResult(
|
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
|
||||||
embeddings=embeddings,
|
|
||||||
usage=usage,
|
|
||||||
model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
"""
|
"""
|
||||||
@ -116,13 +117,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
if len(texts) == 0:
|
if len(texts) == 0:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
full_text = ' '.join(texts)
|
full_text = " ".join(texts)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self._tokenize(
|
response = self._tokenize(
|
||||||
model=model,
|
model=model, credentials=credentials, text=full_text
|
||||||
credentials=credentials,
|
|
||||||
text=full_text
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise self._transform_invoke_error(e)
|
raise self._transform_invoke_error(e)
|
||||||
@ -141,12 +140,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
return Tokens([], [], {})
|
return Tokens([], [], {})
|
||||||
|
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
|
|
||||||
response = client.tokenize(
|
response = client.tokenize(text=text, model=model)
|
||||||
text=text,
|
|
||||||
model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -160,15 +156,13 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# call embedding model
|
# call embedding model
|
||||||
self._embedding_invoke(
|
self._embedding_invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
model=model,
|
|
||||||
credentials=credentials,
|
|
||||||
texts=['ping']
|
|
||||||
)
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
|
def _embedding_invoke(
|
||||||
|
self, model: str, credentials: dict, texts: list[str]
|
||||||
|
) -> tuple[list[list[float]], int]:
|
||||||
"""
|
"""
|
||||||
Invoke embedding model
|
Invoke embedding model
|
||||||
|
|
||||||
@ -178,18 +172,20 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:return: embeddings and used tokens
|
:return: embeddings and used tokens
|
||||||
"""
|
"""
|
||||||
# initialize client
|
# initialize client
|
||||||
client = cohere.Client(credentials.get('api_key'))
|
client = cohere.Client(credentials.get("api_key"))
|
||||||
|
|
||||||
# call embedding model
|
# call embedding model
|
||||||
response = client.embed(
|
response = client.embed(
|
||||||
texts=texts,
|
texts=texts,
|
||||||
model=model,
|
model=model,
|
||||||
input_type='search_document' if len(texts) > 1 else 'search_query'
|
input_type="search_document" if len(texts) > 1 else "search_query",
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.embeddings, response.meta['billed_units']['input_tokens']
|
return response.embeddings, response.meta["billed_units"]["input_tokens"]
|
||||||
|
|
||||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
def _calc_response_usage(
|
||||||
|
self, model: str, credentials: dict, tokens: int
|
||||||
|
) -> EmbeddingUsage:
|
||||||
"""
|
"""
|
||||||
Calculate response usage
|
Calculate response usage
|
||||||
|
|
||||||
@ -203,7 +199,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
model=model,
|
model=model,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
price_type=PriceType.INPUT,
|
price_type=PriceType.INPUT,
|
||||||
tokens=tokens
|
tokens=tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
@ -214,7 +210,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
price_unit=input_price_info.unit,
|
price_unit=input_price_info.unit,
|
||||||
total_price=input_price_info.total_amount,
|
total_price=input_price_info.total_amount,
|
||||||
currency=input_price_info.currency,
|
currency=input_price_info.currency,
|
||||||
latency=time.perf_counter() - self.started_at
|
latency=time.perf_counter() - self.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
return usage
|
return usage
|
||||||
@ -230,14 +226,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [cohere.CohereConnectionError],
|
||||||
cohere.CohereConnectionError
|
|
||||||
],
|
|
||||||
InvokeServerUnavailableError: [],
|
InvokeServerUnavailableError: [],
|
||||||
InvokeRateLimitError: [],
|
InvokeRateLimitError: [],
|
||||||
InvokeAuthorizationError: [],
|
InvokeAuthorizationError: [],
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [
|
||||||
cohere.CohereAPIError,
|
cohere.CohereAPIError,
|
||||||
cohere.CohereError,
|
cohere.CohereError,
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -21,11 +25,12 @@ class GoogleProvider(ModelProvider):
|
|||||||
|
|
||||||
# Use `gemini-pro` model for validate,
|
# Use `gemini-pro` model for validate,
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='gemini-pro',
|
model="gemini-pro", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
@ -5,10 +5,19 @@ from typing import Optional, Union
|
|||||||
import google.api_core.exceptions as exceptions
|
import google.api_core.exceptions as exceptions
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import google.generativeai.client as client
|
import google.generativeai.client as client
|
||||||
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
|
from google.generativeai.types import (
|
||||||
|
ContentType,
|
||||||
|
GenerateContentResponse,
|
||||||
|
HarmBlockThreshold,
|
||||||
|
HarmCategory,
|
||||||
|
)
|
||||||
from google.generativeai.types.content_types import to_part
|
from google.generativeai.types.content_types import to_part
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from model_providers.core.model_runtime.entities.llm_entities import (
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
)
|
||||||
from model_providers.core.model_runtime.entities.message_entities import (
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -26,8 +35,12 @@ from model_providers.core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
|
||||||
|
LargeLanguageModel,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -42,12 +55,17 @@ if you are not sure about the structure.
|
|||||||
|
|
||||||
|
|
||||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||||
|
def _invoke(
|
||||||
def _invoke(self, model: str, credentials: dict,
|
self,
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
model: str,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
credentials: dict,
|
||||||
stream: bool = True, user: Optional[str] = None) \
|
prompt_messages: list[PromptMessage],
|
||||||
-> Union[LLMResult, Generator]:
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -62,10 +80,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
# invoke model
|
# invoke model
|
||||||
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
|
return self._generate(
|
||||||
|
model, credentials, prompt_messages, model_parameters, stop, stream, user
|
||||||
|
)
|
||||||
|
|
||||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def get_num_tokens(
|
||||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Get number of tokens for given prompt messages
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
@ -89,8 +114,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
messages = messages.copy() # don't mutate the original list
|
messages = messages.copy() # don't mutate the original list
|
||||||
|
|
||||||
text = "".join(
|
text = "".join(
|
||||||
self._convert_one_message_to_text(message)
|
self._convert_one_message_to_text(message) for message in messages
|
||||||
for message in messages
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return text.rstrip()
|
return text.rstrip()
|
||||||
@ -106,16 +130,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
ping_message = PromptMessage(content="ping", role="system")
|
ping_message = PromptMessage(content="ping", role="system")
|
||||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
self._generate(
|
||||||
|
model, credentials, [ping_message], {"max_tokens_to_sample": 5}
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
def _generate(
|
||||||
def _generate(self, model: str, credentials: dict,
|
self,
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
model: str,
|
||||||
stop: Optional[list[str]] = None, stream: bool = True,
|
credentials: dict,
|
||||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
@ -129,14 +160,14 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: full response or stream response chunk generator result
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
config_kwargs = model_parameters.copy()
|
config_kwargs = model_parameters.copy()
|
||||||
config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
|
config_kwargs["max_output_tokens"] = config_kwargs.pop(
|
||||||
|
"max_tokens_to_sample", None
|
||||||
|
)
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
config_kwargs["stop_sequences"] = stop
|
config_kwargs["stop_sequences"] = stop
|
||||||
|
|
||||||
google_model = genai.GenerativeModel(
|
google_model = genai.GenerativeModel(model_name=model)
|
||||||
model_name=model
|
|
||||||
)
|
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
@ -146,14 +177,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
content = self._format_message_to_glm_content(last_msg)
|
content = self._format_message_to_glm_content(last_msg)
|
||||||
history.append(content)
|
history.append(content)
|
||||||
else:
|
else:
|
||||||
for msg in prompt_messages: # makes message roles strictly alternating
|
for msg in prompt_messages: # makes message roles strictly alternating
|
||||||
content = self._format_message_to_glm_content(msg)
|
content = self._format_message_to_glm_content(msg)
|
||||||
if history and history[-1]["role"] == content["role"]:
|
if history and history[-1]["role"] == content["role"]:
|
||||||
history[-1]["parts"].extend(content["parts"])
|
history[-1]["parts"].extend(content["parts"])
|
||||||
else:
|
else:
|
||||||
history.append(content)
|
history.append(content)
|
||||||
|
|
||||||
|
|
||||||
# Create a new ClientManager with tenant's API key
|
# Create a new ClientManager with tenant's API key
|
||||||
new_client_manager = client._ClientManager()
|
new_client_manager = client._ClientManager()
|
||||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||||
@ -161,7 +191,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
google_model._client = new_custom_client
|
google_model._client = new_custom_client
|
||||||
|
|
||||||
safety_settings={
|
safety_settings = {
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
@ -170,20 +200,27 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
response = google_model.generate_content(
|
response = google_model.generate_content(
|
||||||
contents=history,
|
contents=history,
|
||||||
generation_config=genai.types.GenerationConfig(
|
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
||||||
**config_kwargs
|
|
||||||
),
|
|
||||||
stream=stream,
|
stream=stream,
|
||||||
safety_settings=safety_settings
|
safety_settings=safety_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_stream_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
return self._handle_generate_response(
|
||||||
|
model, credentials, response, prompt_messages
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse,
|
def _handle_generate_response(
|
||||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: GenerateContentResponse,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm response
|
Handle llm response
|
||||||
|
|
||||||
@ -194,16 +231,18 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: llm response
|
:return: llm response
|
||||||
"""
|
"""
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
||||||
content=response.text
|
|
||||||
)
|
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
completion_tokens = self.get_num_tokens(
|
||||||
|
model, credentials, [assistant_prompt_message]
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
# transform response
|
# transform response
|
||||||
result = LLMResult(
|
result = LLMResult(
|
||||||
@ -215,8 +254,13 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse,
|
def _handle_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage]) -> Generator:
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
response: GenerateContentResponse,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
) -> Generator:
|
||||||
"""
|
"""
|
||||||
Handle llm stream response
|
Handle llm stream response
|
||||||
|
|
||||||
@ -232,28 +276,29 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=content if content else '',
|
content=content if content else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response._done:
|
if not response._done:
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
delta=LLMResultChunkDelta(
|
delta=LLMResultChunkDelta(
|
||||||
index=index,
|
index=index, message=assistant_prompt_message
|
||||||
message=assistant_prompt_message
|
),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
completion_tokens = self.get_num_tokens(
|
||||||
|
model, credentials, [assistant_prompt_message]
|
||||||
|
)
|
||||||
|
|
||||||
# transform usage
|
# transform usage
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model, credentials, prompt_tokens, completion_tokens
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
@ -262,8 +307,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
index=index,
|
index=index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=chunk.candidates[0].finish_reason,
|
finish_reason=chunk.candidates[0].finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
||||||
@ -302,21 +347,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if (isinstance(message.content, str)):
|
if isinstance(message.content, str):
|
||||||
parts.append(to_part(message.content))
|
parts.append(to_part(message.content))
|
||||||
else:
|
else:
|
||||||
for c in message.content:
|
for c in message.content:
|
||||||
if c.type == PromptMessageContentType.TEXT:
|
if c.type == PromptMessageContentType.TEXT:
|
||||||
parts.append(to_part(c.data))
|
parts.append(to_part(c.data))
|
||||||
else:
|
else:
|
||||||
metadata, data = c.data.split(',', 1)
|
metadata, data = c.data.split(",", 1)
|
||||||
mime_type = metadata.split(';', 1)[0].split(':')[1]
|
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||||
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
|
blob = {"inline_data": {"mime_type": mime_type, "data": data}}
|
||||||
parts.append(blob)
|
parts.append(blob)
|
||||||
|
|
||||||
glm_content = {
|
glm_content = {
|
||||||
"role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model",
|
"role": "user"
|
||||||
"parts": parts
|
if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM)
|
||||||
|
else "model",
|
||||||
|
"parts": parts,
|
||||||
}
|
}
|
||||||
|
|
||||||
return glm_content
|
return glm_content
|
||||||
@ -332,25 +379,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
:return: Invoke error mapping
|
:return: Invoke error mapping
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
InvokeConnectionError: [
|
InvokeConnectionError: [exceptions.RetryError],
|
||||||
exceptions.RetryError
|
|
||||||
],
|
|
||||||
InvokeServerUnavailableError: [
|
InvokeServerUnavailableError: [
|
||||||
exceptions.ServiceUnavailable,
|
exceptions.ServiceUnavailable,
|
||||||
exceptions.InternalServerError,
|
exceptions.InternalServerError,
|
||||||
exceptions.BadGateway,
|
exceptions.BadGateway,
|
||||||
exceptions.GatewayTimeout,
|
exceptions.GatewayTimeout,
|
||||||
exceptions.DeadlineExceeded
|
exceptions.DeadlineExceeded,
|
||||||
],
|
],
|
||||||
InvokeRateLimitError: [
|
InvokeRateLimitError: [
|
||||||
exceptions.ResourceExhausted,
|
exceptions.ResourceExhausted,
|
||||||
exceptions.TooManyRequests
|
exceptions.TooManyRequests,
|
||||||
],
|
],
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
exceptions.Unauthenticated,
|
exceptions.Unauthenticated,
|
||||||
exceptions.PermissionDenied,
|
exceptions.PermissionDenied,
|
||||||
exceptions.Unauthenticated,
|
exceptions.Unauthenticated,
|
||||||
exceptions.Forbidden
|
exceptions.Forbidden,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [
|
InvokeBadRequestError: [
|
||||||
exceptions.BadRequest,
|
exceptions.BadRequest,
|
||||||
@ -366,5 +411,5 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
exceptions.PreconditionFailed,
|
exceptions.PreconditionFailed,
|
||||||
exceptions.RequestRangeNotSatisfiable,
|
exceptions.RequestRangeNotSatisfiable,
|
||||||
exceptions.Cancelled,
|
exceptions.Cancelled,
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from model_providers.core.model_runtime.errors.validate import (
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
CredentialsValidateFailedError,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class GroqProvider(ModelProvider):
|
|
||||||
|
|
||||||
|
class GroqProvider(ModelProvider):
|
||||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Validate provider credentials
|
Validate provider credentials
|
||||||
@ -19,11 +23,12 @@ class GroqProvider(ModelProvider):
|
|||||||
model_instance = self.get_model_instance(ModelType.LLM)
|
model_instance = self.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
model_instance.validate_credentials(
|
model_instance.validate_credentials(
|
||||||
model='llama2-70b-4096',
|
model="llama2-70b-4096", credentials=credentials
|
||||||
credentials=credentials
|
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
raise ex
|
raise ex
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
logger.exception(
|
||||||
|
f"{self.get_provider_schema().provider} credentials validate failed"
|
||||||
|
)
|
||||||
raise ex
|
raise ex
|
||||||
|
|||||||
@ -2,18 +2,31 @@ from collections.abc import Generator
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from model_providers.core.model_runtime.entities.message_entities import (
|
||||||
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
|
||||||
|
OAIAPICompatLargeLanguageModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(
|
||||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
self,
|
||||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
model: str,
|
||||||
stream: bool = True, user: Optional[str] = None) \
|
credentials: dict,
|
||||||
-> Union[LLMResult, Generator]:
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
self._add_custom_parameters(credentials)
|
self._add_custom_parameters(credentials)
|
||||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
return super()._invoke(
|
||||||
|
model, credentials, prompt_messages, model_parameters, tools, stop, stream
|
||||||
|
)
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
self._add_custom_parameters(credentials)
|
self._add_custom_parameters(credentials)
|
||||||
@ -21,6 +34,5 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_custom_parameters(credentials: dict) -> None:
|
def _add_custom_parameters(credentials: dict) -> None:
|
||||||
credentials['mode'] = 'chat'
|
credentials["mode"] = "chat"
|
||||||
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
|
credentials["endpoint_url"] = "https://api.groq.com/openai/v1"
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,12 @@
|
|||||||
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
|
from huggingface_hub.utils import BadRequestError, HfHubHTTPError
|
||||||
|
|
||||||
from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
|
from model_providers.core.model_runtime.errors.invoke import (
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _CommonHuggingfaceHub:
|
class _CommonHuggingfaceHub:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
return {
|
return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]}
|
||||||
InvokeBadRequestError: [
|
|
||||||
HfHubHTTPError,
|
|
||||||
BadRequestError
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
|
||||||
|
ModelProvider,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceHubProvider(ModelProvider):
|
class HuggingfaceHubProvider(ModelProvider):
|
||||||
|
|
||||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -7,7 +7,12 @@ from huggingface_hub.utils import BadRequestError
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMMode,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,

@ -23,22 +28,35 @@ from model_providers.core.model_runtime.entities.model_entities import (
    ModelType,
    ParameterRule,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import (
    _CommonHuggingfaceHub,
)


class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        client = InferenceClient(token=credentials["huggingfacehub_api_token"])

        if credentials["huggingfacehub_api_type"] == "inference_endpoints":
            model = credentials["huggingfacehub_endpoint_url"]

        if "baichuan" in model.lower():
            stream = False

        response = client.text_generation(

@ -47,71 +65,97 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
            stream=stream,
            model=model,
            stop_sequences=stop,
            **model_parameters,
        )

        if stream:
            return self._handle_generate_stream_response(
                model, credentials, prompt_messages, response
            )

        return self._handle_generate_response(
            model, credentials, prompt_messages, response
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        prompt = self._convert_messages_to_prompt(prompt_messages)
        return self._get_num_tokens_by_gpt2(prompt)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        try:
            if "huggingfacehub_api_type" not in credentials:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Endpoint Type must be provided."
                )

            if credentials["huggingfacehub_api_type"] not in (
                "inference_endpoints",
                "hosted_inference_api",
            ):
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Endpoint Type is invalid."
                )

            if "huggingfacehub_api_token" not in credentials:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Access Token must be provided."
                )

            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
                if "huggingfacehub_endpoint_url" not in credentials:
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub Endpoint URL must be provided."
                    )

                if "task_type" not in credentials:
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub Task Type must be provided."
                    )
            elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
                credentials["task_type"] = self._get_hosted_model_task_type(
                    credentials["huggingfacehub_api_token"], model
                )

            if credentials["task_type"] not in (
                "text2text-generation",
                "text-generation",
            ):
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Task Type must be one of text2text-generation, "
                    "text-generation."
                )

            client = InferenceClient(token=credentials["huggingfacehub_api_token"])

            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
                model = credentials["huggingfacehub_endpoint_url"]

            try:
                client.text_generation(prompt="Who are you?", stream=True, model=model)
            except BadRequestError as e:
                raise CredentialsValidateFailedError(
                    "Only available for models running on with the `text-generation-inference`. "
                    "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
                )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> Optional[AIModelEntity]:
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value},
            parameter_rules=self._get_customizable_model_parameter_rules(),
        )

        return entity

@ -119,26 +163,27 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
    @staticmethod
    def _get_customizable_model_parameter_rules() -> list[ParameterRule]:
        temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(
            DefaultParameterName.TEMPERATURE
        ).copy()
        temperature_rule_dict["name"] = "temperature"
        temperature_rule = ParameterRule(**temperature_rule_dict)
        temperature_rule.default = 0.5

        top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy()
        top_p_rule_dict["name"] = "top_p"
        top_p_rule = ParameterRule(**top_p_rule_dict)
        top_p_rule.default = 0.5

        top_k_rule = ParameterRule(
            name="top_k",
            label={
                "en_US": "Top K",
                "zh_Hans": "Top K",
            },
            type="int",
            help={
                "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
                "zh_Hans": "保留的最高概率词汇标记的数量。",
            },
            required=False,
            default=2,

@ -148,15 +193,15 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
        )

        max_new_tokens = ParameterRule(
            name="max_new_tokens",
            label={
                "en_US": "Max New Tokens",
                "zh_Hans": "最大新标记",
            },
            type="int",
            help={
                "en_US": "Maximum number of generated tokens.",
                "zh_Hans": "生成的标记的最大数量。",
            },
            required=False,
            default=20,

@ -166,42 +211,51 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
        )

        seed = ParameterRule(
            name="seed",
            label={
                "en_US": "Random sampling seed",
                "zh_Hans": "随机采样种子",
            },
            type="int",
            help={
                "en_US": "Random sampling seed.",
                "zh_Hans": "随机采样种子。",
            },
            required=False,
            precision=0,
        )

        repetition_penalty = ParameterRule(
            name="repetition_penalty",
            label={
                "en_US": "Repetition Penalty",
                "zh_Hans": "重复惩罚",
            },
            type="float",
            help={
                "en_US": "The parameter for repetition penalty. 1.0 means no penalty.",
                "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。",
            },
            required=False,
            precision=1,
        )

        return [
            temperature_rule,
            top_k_rule,
            top_p_rule,
            max_new_tokens,
            seed,
            repetition_penalty,
        ]

    def _handle_generate_stream_response(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        response: Generator,
    ) -> Generator:
        index = -1
        for chunk in response:
            # skip special tokens

@ -210,15 +264,17 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
            index += 1

            assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text)

            if chunk.details:
                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
                completion_tokens = self.get_num_tokens(
                    model, credentials, [assistant_prompt_message]
                )

                usage = self._calc_response_usage(
                    model, credentials, prompt_tokens, completion_tokens
                )

                yield LLMResultChunk(
                    model=model,

@ -240,20 +296,28 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                    ),
                )

    def _handle_generate_response(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        response: any,
    ) -> LLMResult:
        if isinstance(response, str):
            content = response
        else:
            content = response.generated_text

        assistant_prompt_message = AssistantPromptMessage(content=content)

        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
        completion_tokens = self.get_num_tokens(
            model, credentials, [assistant_prompt_message]
        )

        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )

        result = LLMResult(
            model=model,

@ -270,15 +334,22 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
        try:
            if not model_info:
                raise ValueError(f"Model {model_name} not found.")

            if (
                "inference" in model_info.cardData
                and not model_info.cardData["inference"]
            ):
                raise ValueError(
                    f"Inference API has been turned off for this model {model_name}."
                )

            valid_tasks = ("text2text-generation", "text-generation")
            if model_info.pipeline_tag not in valid_tasks:
                raise ValueError(
                    f"Model {model_name} is not a valid task, "
                    f"must be one of {valid_tasks}."
                )
        except Exception as e:
            raise CredentialsValidateFailedError(f"{str(e)}")

@ -288,8 +359,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message) for message in messages
        )

        return text.rstrip()
@ -7,35 +7,51 @@ import requests
from huggingface_hub import HfApi, InferenceClient

from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelType,
    PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    EmbeddingUsage,
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
    TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.huggingface_hub._common import (
    _CommonHuggingfaceHub,
)

HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/"


class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        client = InferenceClient(token=credentials["huggingfacehub_api_token"])

        execute_model = model

        if credentials["huggingfacehub_api_type"] == "inference_endpoints":
            execute_model = credentials["huggingfacehub_endpoint_url"]

        output = client.post(
            json={
                "inputs": texts,
                "options": {"wait_for_model": False, "use_cache": False},
            },
            model=execute_model,
        )

        embeddings = json.loads(output.decode())

@ -43,9 +59,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
        usage = self._calc_response_usage(model, credentials, tokens)

        return TextEmbeddingResult(
            embeddings=self._mean_pooling(embeddings), usage=usage, model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:

@ -56,52 +70,64 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
    def validate_credentials(self, model: str, credentials: dict) -> None:
        try:
            if "huggingfacehub_api_type" not in credentials:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Endpoint Type must be provided."
                )

            if "huggingfacehub_api_token" not in credentials:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub API Token must be provided."
                )

            if credentials["huggingfacehub_api_type"] == "inference_endpoints":
                if "huggingface_namespace" not in credentials:
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub User Name / Organization Name must be provided."
                    )

                if "huggingfacehub_endpoint_url" not in credentials:
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub Endpoint URL must be provided."
                    )

                if "task_type" not in credentials:
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub Task Type must be provided."
                    )

                if credentials["task_type"] != "feature-extraction":
                    raise CredentialsValidateFailedError(
                        "Huggingface Hub Task Type is invalid."
                    )

                self._check_endpoint_url_model_repository_name(credentials, model)

                model = credentials["huggingfacehub_endpoint_url"]

            elif credentials["huggingfacehub_api_type"] == "hosted_inference_api":
                self._check_hosted_model_task_type(
                    credentials["huggingfacehub_api_token"], model
                )
            else:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Endpoint Type is invalid."
                )

            client = InferenceClient(token=credentials["huggingfacehub_api_token"])
            client.feature_extraction(text="hello world", model=model)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> Optional[AIModelEntity]:
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={"context_size": 10000, "max_chunks": 1},
        )
        return entity

@ -118,34 +144,47 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
            return embeddings

        # For example two: List[List[List[float]]], need to mean_pooling.
        sentence_embeddings = [
            np.mean(embedding[0], axis=0).tolist() for embedding in embeddings
        ]
        return sentence_embeddings

    @staticmethod
    def _check_hosted_model_task_type(
        huggingfacehub_api_token: str, model_name: str
    ) -> None:
        hf_api = HfApi(token=huggingfacehub_api_token)
        model_info = hf_api.model_info(repo_id=model_name)

        try:
            if not model_info:
                raise ValueError(f"Model {model_name} not found.")

            if (
                "inference" in model_info.cardData
                and not model_info.cardData["inference"]
            ):
                raise ValueError(
                    f"Inference API has been turned off for this model {model_name}."
                )

            valid_tasks = "feature-extraction"
            if model_info.pipeline_tag not in valid_tasks:
                raise ValueError(
                    f"Model {model_name} is not a valid task, "
                    f"must be one of {valid_tasks}."
                )
        except Exception as e:
            raise CredentialsValidateFailedError(f"{str(e)}")

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens,
        )

        # transform usage

@ -156,7 +195,7 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

@ -166,25 +205,29 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
        try:
            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
            headers = {
                "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
                "Content-Type": "application/json",
            }

            response = requests.get(url=url, headers=headers)

            if response.status_code != 200:
                raise ValueError("User Name or Organization Name is invalid.")

            model_repository_name = ""

            for item in response.json().get("items", []):
                if (
                    item.get("status", {}).get("url")
                    == credentials["huggingfacehub_endpoint_url"]
                ):
                    model_repository_name = item.get("model", {}).get("repository")
                    break

            if model_repository_name != model_name:
                raise ValueError(
                    f"Model Name {model_name} is invalid. Please check it on the inference endpoints console."
                )

        except Exception as e:
            raise ValueError(str(e))
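The `_mean_pooling` helper above averages per-token vectors into one sentence vector when the endpoint returns token-level embeddings. A standalone sketch of that pooling step, using made-up shapes (one text, three tokens, four dimensions):

# Sketch only: the mean-pooling step over token embeddings.
import numpy as np

token_embeddings = [[
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
    [0.9, 1.0, 1.1, 1.2],
]]

sentence_vector = np.mean(token_embeddings[0], axis=0).tolist()
print(sentence_vector)  # approximately [0.5, 0.6, 0.7, 0.8]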
@ -1,14 +1,17 @@
import logging

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class JinaProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

@ -22,11 +25,12 @@ class JinaProvider(ModelProvider):
            # Use `jina-embeddings-v2-base-en` model for validate,
            # no matter what model you pass in, text completion model or chat model
            model_instance.validate_credentials(
                model="jina-embeddings-v2-base-en", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(
                f"{self.get_provider_schema().provider} credentials validate failed"
            )
            raise ex
@ -2,7 +2,10 @@ from typing import Optional
import httpx

from model_providers.core.model_runtime.entities.rerank_entities import (
    RerankDocument,
    RerankResult,
)
from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,

@ -11,8 +14,12 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.rerank_model import (
    RerankModel,
)


class JinaRerankModel(RerankModel):

@ -20,9 +27,16 @@ class JinaRerankModel(RerankModel):
    Model class for Jina rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

@ -45,26 +59,29 @@ class JinaRerankModel(RerankModel):
                    "model": model,
                    "query": query,
                    "documents": docs,
                    "top_n": top_n,
                },
                headers={"Authorization": f"Bearer {credentials.get('api_key')}"},
            )
            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results["results"]:
                rerank_document = RerankDocument(
                    index=result["index"],
                    text=result["document"]["text"],
                    score=result["relevance_score"],
                )
                if (
                    score_threshold is None
                    or result["relevance_score"] >= score_threshold
                ):
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """

@ -75,7 +92,6 @@ class JinaRerankModel(RerankModel):
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,

@ -86,7 +102,7 @@ class JinaRerankModel(RerankModel):
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

@ -99,7 +115,7 @@ class JinaRerankModel(RerankModel):
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

@ -14,19 +14,19 @@ class JinaTokenizer:
        with cls._lock:
            if cls._tokenizer is None:
                base_path = abspath(__file__)
                gpt2_tokenizer_path = join(dirname(base_path), "tokenizer")
                cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
            return cls._tokenizer

    @classmethod
    def _get_num_tokens_by_jina_base(cls, text: str) -> int:
        """
        use jina tokenizer to get num tokens
        """
        tokenizer = cls._get_tokenizer()
        tokens = tokenizer.encode(text)
        return len(tokens)

    @classmethod
    def get_num_tokens(cls, text: str) -> int:
        return cls._get_num_tokens_by_jina_base(text)

@ -5,7 +5,10 @@ from typing import Optional
from requests import post

from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    EmbeddingUsage,
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,

@ -14,21 +17,37 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
    TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import (
    JinaTokenizer,
)


class JinaTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Jina text embedding model.
    """

    api_base: str = "https://api.jina.ai/v1/embeddings"
    models: list[str] = [
        "jina-embeddings-v2-base-en",
        "jina-embeddings-v2-small-en",
        "jina-embeddings-v2-base-zh",
        "jina-embeddings-v2-base-de",
    ]

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

@ -38,31 +57,28 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
        :param user: unique user id
        :return: embeddings result
        """
        api_key = credentials["api_key"]
        if model not in self.models:
            raise InvokeBadRequestError("Invalid model name")
        if not api_key:
            raise CredentialsValidateFailedError("api_key is required")
        url = self.api_base
        headers = {
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        }

        data = {"model": model, "input": texts}

        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            try:
                resp = response.json()
                msg = resp["detail"]
                if response.status_code == 401:
                    raise InvokeAuthorizationError(msg)
                elif response.status_code == 429:

@ -72,23 +88,27 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
                else:
                    raise InvokeError(msg)
            except JSONDecodeError as e:
                raise InvokeServerUnavailableError(
                    f"Failed to convert response to json: {e} with text: {response.text}"
                )

        try:
            resp = response.json()
            embeddings = resp["data"]
            usage = resp["usage"]
        except Exception as e:
            raise InvokeServerUnavailableError(
                f"Failed to convert response to json: {e} with text: {response.text}"
            )

        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=usage["total_tokens"]
        )

        result = TextEmbeddingResult(
            model=model,
            embeddings=[[float(data) for data in x["embedding"]] for x in embeddings],
            usage=usage,
        )

        return result

@ -117,31 +137,23 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=["ping"])
        except InvokeAuthorizationError:
            raise CredentialsValidateFailedError("Invalid api key")

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage

@ -155,7 +167,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens,
        )

        # transform usage

@ -166,7 +178,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage
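`JinaTextEmbeddingModel._invoke` above is a thin wrapper over one HTTP POST. A minimal sketch of that raw call, not part of this commit; the API key is a placeholder:

# Sketch only: the HTTP request behind the Jina embeddings call.
from json import dumps
import requests

resp = requests.post(
    "https://api.jina.ai/v1/embeddings",
    headers={
        "Authorization": "Bearer jina_placeholder_key",  # placeholder credential
        "Content-Type": "application/json",
    },
    data=dumps({"model": "jina-embeddings-v2-base-en", "input": ["hello", "world"]}),
)
resp.raise_for_status()
vectors = [item["embedding"] for item in resp.json()["data"]]
print(len(vectors), "embeddings returned")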
@ -21,7 +21,12 @@ from openai.types.completion import Completion
from yarl import URL

from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.llm_entities import (
    LLMMode,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,

@ -45,34 +50,60 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)
from model_providers.core.model_runtime.utils import helper


class LocalAILarguageModel(LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        return self._generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool] | None = None,
    ) -> int:
        # tools is not supported yet
        return self._num_tokens_from_messages(prompt_messages, tools=tools)

    def _num_tokens_from_messages(
        self, messages: list[PromptMessage], tools: list[PromptMessageTool]
    ) -> int:
        """
        Calculate num tokens for baichuan model
        LocalAI does not supports
        """

        def tokens(text: str):
            """
            We cloud not determine which tokenizer to use, cause the model is customized.
            So we use gpt2 tokenizer to calculate the num tokens for convenience.
            """
            return self._get_num_tokens_by_gpt2(text)

@ -85,10 +116,10 @@ class LocalAILarguageModel(LargeLanguageModel):
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

@ -124,7 +155,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            num_tokens += self._num_tokens_for_tools(tools)

        return num_tokens

    def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for tool calling

@ -133,36 +164,37 @@ class LocalAILarguageModel(LargeLanguageModel):
        :param tools: tools for tool calling
        :return: number of tokens
        """

        def tokens(text: str):
            return self._get_num_tokens_by_gpt2(text)

        num_tokens = 0
        for tool in tools:
            # calculate num tokens for function object
            num_tokens += tokens("name")
            num_tokens += tokens(tool.name)
            num_tokens += tokens("description")
            num_tokens += tokens(tool.description)
            parameters = tool.parameters
            num_tokens += tokens("parameters")
            num_tokens += tokens("type")
            num_tokens += tokens(parameters.get("type"))
            if "properties" in parameters:
                num_tokens += tokens("properties")
                for key, value in parameters.get("properties").items():
                    num_tokens += tokens(key)
                    for field_key, field_value in value.items():
                        num_tokens += tokens(field_key)
                        if field_key == "enum":
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += tokens(enum_field)
                        else:
                            num_tokens += tokens(field_key)
                            num_tokens += tokens(str(field_value))
            if "required" in parameters:
                num_tokens += tokens("required")
                for required_field in parameters["required"]:
                    num_tokens += 3
                    num_tokens += tokens(required_field)

@ -177,141 +209,166 @@ class LocalAILarguageModel(LargeLanguageModel):
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={
                    "max_tokens": 10,
                },
                stop=[],
                stream=False,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}")
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
def get_customizable_model_schema(
|
||||||
|
self, model: str, credentials: dict
|
||||||
|
) -> AIModelEntity | None:
|
||||||
completion_model = None
|
completion_model = None
|
||||||
if credentials['completion_type'] == 'chat_completion':
|
if credentials["completion_type"] == "chat_completion":
|
||||||
completion_model = LLMMode.CHAT.value
|
completion_model = LLMMode.CHAT.value
|
||||||
elif credentials['completion_type'] == 'completion':
|
elif credentials["completion_type"] == "completion":
|
||||||
completion_model = LLMMode.COMPLETION.value
|
completion_model = LLMMode.COMPLETION.value
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
raise ValueError(
|
||||||
|
f"Unknown completion type {credentials['completion_type']}"
|
||||||
|
)
|
||||||
|
|
||||||
rules = [
|
rules = [
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='temperature',
|
name="temperature",
|
||||||
type=ParameterType.FLOAT,
|
type=ParameterType.FLOAT,
|
||||||
use_template='temperature',
|
use_template="temperature",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
|
||||||
zh_Hans='温度',
|
|
||||||
en_US='Temperature'
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='top_p',
|
name="top_p",
|
||||||
type=ParameterType.FLOAT,
|
type=ParameterType.FLOAT,
|
||||||
use_template='top_p',
|
use_template="top_p",
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
|
||||||
zh_Hans='Top P',
|
|
||||||
en_US='Top P'
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
ParameterRule(
|
ParameterRule(
|
||||||
name='max_tokens',
|
name="max_tokens",
|
||||||
type=ParameterType.INT,
|
type=ParameterType.INT,
|
||||||
use_template='max_tokens',
|
use_template="max_tokens",
|
||||||
min=1,
|
min=1,
|
||||||
max=2048,
|
max=2048,
|
||||||
default=512,
|
default=512,
|
||||||
label=I18nObject(
|
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
|
||||||
zh_Hans='最大生成长度',
|
),
|
||||||
en_US='Max Tokens'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
model_properties = {
|
model_properties = (
|
||||||
ModelPropertyKey.MODE: completion_model,
|
{
|
||||||
} if completion_model else {}
|
ModelPropertyKey.MODE: completion_model,
|
||||||
|
}
|
||||||
|
if completion_model
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
||||||
|
credentials.get("context_size", "2048")
|
||||||
|
)
|
||||||
|
|
||||||
entity = AIModelEntity(
|
entity = AIModelEntity(
|
||||||
model=model,
|
model=model,
|
||||||
label=I18nObject(
|
label=I18nObject(en_US=model),
|
||||||
en_US=model
|
|
||||||
),
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
model_properties=model_properties,
|
model_properties=model_properties,
|
||||||
parameter_rules=rules
|
parameter_rules=rules,
|
||||||
)
|
)
|
||||||
|
|
||||||
return entity
|
return entity
|
||||||
|
|
||||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
def _generate(
|
||||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
self,
|
||||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
model: str,
|
||||||
-> LLMResult | Generator:
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: list[PromptMessageTool] | None = None,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: str | None = None,
|
||||||
|
) -> LLMResult | Generator:
|
||||||
kwargs = self._to_client_kwargs(credentials)
|
kwargs = self._to_client_kwargs(credentials)
|
||||||
# init model client
|
# init model client
|
||||||
client = OpenAI(**kwargs)
|
client = OpenAI(**kwargs)
|
||||||
|
|
||||||
model_name = model
|
model_name = model
|
||||||
completion_type = credentials['completion_type']
|
completion_type = credentials["completion_type"]
|
||||||
|
|
||||||
extra_model_kwargs = {
|
extra_model_kwargs = {
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
}
|
}
|
||||||
if stop:
|
if stop:
|
||||||
extra_model_kwargs['stop'] = stop
|
extra_model_kwargs["stop"] = stop
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
extra_model_kwargs['user'] = user
|
extra_model_kwargs["user"] = user
|
||||||
|
|
||||||
if tools and len(tools) > 0:
|
if tools and len(tools) > 0:
|
||||||
extra_model_kwargs['functions'] = [
|
extra_model_kwargs["functions"] = [
|
||||||
helper.dump_model(tool) for tool in tools
|
helper.dump_model(tool) for tool in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
if completion_type == 'chat_completion':
|
if completion_type == "chat_completion":
|
||||||
result = client.chat.completions.create(
|
result = client.chat.completions.create(
|
||||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
messages=[
|
||||||
|
self._convert_prompt_message_to_dict(m) for m in prompt_messages
|
||||||
|
],
|
||||||
model=model_name,
|
model=model_name,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs,
|
**extra_model_kwargs,
|
||||||
)
|
)
|
||||||
elif completion_type == 'completion':
|
elif completion_type == "completion":
|
||||||
result = client.completions.create(
|
result = client.completions.create(
|
||||||
prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages),
|
prompt=self._convert_prompt_message_to_completion_prompts(
|
||||||
|
prompt_messages
|
||||||
|
),
|
||||||
model=model,
|
model=model,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
**model_parameters,
|
**model_parameters,
|
||||||
**extra_model_kwargs
|
**extra_model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown completion type {completion_type}")
|
raise ValueError(f"Unknown completion type {completion_type}")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
if completion_type == 'completion':
|
if completion_type == "completion":
|
||||||
return self._handle_completion_generate_stream_response(
|
return self._handle_completion_generate_stream_response(
|
||||||
model=model, credentials=credentials, response=result, tools=tools,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
tools=tools,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
return self._handle_chat_generate_stream_response(
|
return self._handle_chat_generate_stream_response(
|
||||||
model=model, credentials=credentials, response=result, tools=tools,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
tools=tools,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
if completion_type == 'completion':
|
if completion_type == "completion":
|
||||||
return self._handle_completion_generate_response(
|
return self._handle_completion_generate_response(
|
||||||
model=model, credentials=credentials, response=result,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
return self._handle_chat_generate_response(
|
return self._handle_chat_generate_response(
|
||||||
model=model, credentials=credentials, response=result, tools=tools,
|
model=model,
|
||||||
prompt_messages=prompt_messages
|
credentials=credentials,
|
||||||
|
response=result,
|
||||||
|
tools=tools,
|
||||||
|
prompt_messages=prompt_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _to_client_kwargs(self, credentials: dict) -> dict:
|
def _to_client_kwargs(self, credentials: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Convert invoke kwargs to client kwargs
|
Convert invoke kwargs to client kwargs
|
||||||
@ -319,13 +376,13 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
:param credentials: credentials dict
|
:param credentials: credentials dict
|
||||||
:return: client kwargs
|
:return: client kwargs
|
||||||
"""
|
"""
|
||||||
if not credentials['server_url'].endswith('/'):
|
if not credentials["server_url"].endswith("/"):
|
||||||
credentials['server_url'] += '/'
|
credentials["server_url"] += "/"
|
||||||
|
|
||||||
client_kwargs = {
|
client_kwargs = {
|
||||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||||
"api_key": "1",
|
"api_key": "1",
|
||||||
"base_url": str(URL(credentials['server_url']) / 'v1'),
|
"base_url": str(URL(credentials["server_url"]) / "v1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
return client_kwargs
|
return client_kwargs
|
||||||
@ -346,41 +403,45 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
if message.tool_calls and len(message.tool_calls) > 0:
|
if message.tool_calls and len(message.tool_calls) > 0:
|
||||||
message_dict["function_call"] = {
|
message_dict["function_call"] = {
|
||||||
"name": message.tool_calls[0].function.name,
|
"name": message.tool_calls[0].function.name,
|
||||||
"arguments": message.tool_calls[0].function.arguments
|
"arguments": message.tool_calls[0].function.arguments,
|
||||||
}
|
}
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage):
|
||||||
message = cast(SystemPromptMessage, message)
|
message = cast(SystemPromptMessage, message)
|
||||||
message_dict = {"role": "system", "content": message.content}
|
message_dict = {"role": "system", "content": message.content}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown message type {type(message)}")
|
raise ValueError(f"Unknown message type {type(message)}")
|
||||||
|
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str:
|
def _convert_prompt_message_to_completion_prompts(
|
||||||
|
self, messages: list[PromptMessage]
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert PromptMessage to completion prompts
|
Convert PromptMessage to completion prompts
|
||||||
"""
|
"""
|
||||||
prompts = ''
|
prompts = ""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if isinstance(message, UserPromptMessage):
|
if isinstance(message, UserPromptMessage):
|
||||||
message = cast(UserPromptMessage, message)
|
message = cast(UserPromptMessage, message)
|
||||||
prompts += f'{message.content}\n'
|
prompts += f"{message.content}\n"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message = cast(AssistantPromptMessage, message)
|
message = cast(AssistantPromptMessage, message)
|
||||||
prompts += f'{message.content}\n'
|
prompts += f"{message.content}\n"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage):
|
||||||
message = cast(SystemPromptMessage, message)
|
message = cast(SystemPromptMessage, message)
|
||||||
prompts += f'{message.content}\n'
|
prompts += f"{message.content}\n"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown message type {type(message)}")
|
raise ValueError(f"Unknown message type {type(message)}")
|
||||||
|
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def _handle_completion_generate_response(self, model: str,
|
def _handle_completion_generate_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
credentials: dict,
|
model: str,
|
||||||
response: Completion,
|
prompt_messages: list[PromptMessage],
|
||||||
) -> LLMResult:
|
credentials: dict,
|
||||||
|
response: Completion,
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm chat response
|
Handle llm chat response
|
||||||
|
|
||||||
@ -393,21 +454,27 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
"""
|
"""
|
||||||
if len(response.choices) == 0:
|
if len(response.choices) == 0:
|
||||||
raise InvokeServerUnavailableError("Empty response")
|
raise InvokeServerUnavailableError("Empty response")
|
||||||
|
|
||||||
assistant_message = response.choices[0].text
|
assistant_message = response.choices[0].text
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=assistant_message,
|
content=assistant_message, tool_calls=[]
|
||||||
tool_calls=[]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._get_num_tokens_by_gpt2(
|
prompt_tokens = self._get_num_tokens_by_gpt2(
|
||||||
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
||||||
)
|
)
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[assistant_prompt_message], tools=[]
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
@ -419,11 +486,14 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_chat_generate_response(self, model: str,
|
def _handle_chat_generate_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
credentials: dict,
|
model: str,
|
||||||
response: ChatCompletion,
|
prompt_messages: list[PromptMessage],
|
||||||
tools: list[PromptMessageTool]) -> LLMResult:
|
credentials: dict,
|
||||||
|
response: ChatCompletion,
|
||||||
|
tools: list[PromptMessageTool],
|
||||||
|
) -> LLMResult:
|
||||||
"""
|
"""
|
||||||
Handle llm chat response
|
Handle llm chat response
|
||||||
|
|
||||||
@ -436,23 +506,33 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
"""
|
"""
|
||||||
if len(response.choices) == 0:
|
if len(response.choices) == 0:
|
||||||
raise InvokeServerUnavailableError("Empty response")
|
raise InvokeServerUnavailableError("Empty response")
|
||||||
|
|
||||||
assistant_message = response.choices[0].message
|
assistant_message = response.choices[0].message
|
||||||
|
|
||||||
# convert function call to tool call
|
# convert function call to tool call
|
||||||
function_calls = assistant_message.function_call
|
function_calls = assistant_message.function_call
|
||||||
tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else [])
|
tool_calls = self._extract_response_tool_calls(
|
||||||
|
[function_calls] if function_calls else []
|
||||||
|
)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=assistant_message.content,
|
content=assistant_message.content, tool_calls=tool_calls
|
||||||
tool_calls=tool_calls
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
|
messages=prompt_messages, tools=tools
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[assistant_prompt_message], tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
response = LLMResult(
|
response = LLMResult(
|
||||||
model=model,
|
model=model,
|
||||||
@ -464,12 +544,15 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_completion_generate_stream_response(self, model: str,
|
def _handle_completion_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
credentials: dict,
|
model: str,
|
||||||
response: Stream[Completion],
|
prompt_messages: list[PromptMessage],
|
||||||
tools: list[PromptMessageTool]) -> Generator:
|
credentials: dict,
|
||||||
full_response = ''
|
response: Stream[Completion],
|
||||||
|
tools: list[PromptMessageTool],
|
||||||
|
) -> Generator:
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
@ -479,26 +562,30 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.text if delta.text else '',
|
content=delta.text if delta.text else "", tool_calls=[]
|
||||||
tool_calls=[]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
# temp_assistant_prompt_message is used to calculate usage
|
# temp_assistant_prompt_message is used to calculate usage
|
||||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=full_response,
|
content=full_response, tool_calls=[]
|
||||||
tool_calls=[]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._get_num_tokens_by_gpt2(
|
prompt_tokens = self._get_num_tokens_by_gpt2(
|
||||||
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
self._convert_prompt_message_to_completion_prompts(prompt_messages)
|
||||||
)
|
)
|
||||||
|
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[temp_assistant_prompt_message], tools=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
|
||||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -507,7 +594,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=delta.finish_reason,
|
finish_reason=delta.finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -523,12 +610,15 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
full_response += delta.text
|
full_response += delta.text
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(self, model: str,
|
def _handle_chat_generate_stream_response(
|
||||||
prompt_messages: list[PromptMessage],
|
self,
|
||||||
credentials: dict,
|
model: str,
|
||||||
response: Stream[ChatCompletionChunk],
|
prompt_messages: list[PromptMessage],
|
||||||
tools: list[PromptMessageTool]) -> Generator:
|
credentials: dict,
|
||||||
full_response = ''
|
response: Stream[ChatCompletionChunk],
|
||||||
|
tools: list[PromptMessageTool],
|
||||||
|
) -> Generator:
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
@ -536,35 +626,46 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
delta = chunk.choices[0]
|
delta = chunk.choices[0]
|
||||||
|
|
||||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
if delta.finish_reason is None and (
|
||||||
|
delta.delta.content is None or delta.delta.content == ""
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# check if there is a tool call in the response
|
# check if there is a tool call in the response
|
||||||
function_calls = None
|
function_calls = None
|
||||||
if delta.delta.function_call:
|
if delta.delta.function_call:
|
||||||
function_calls = [delta.delta.function_call]
|
function_calls = [delta.delta.function_call]
|
||||||
|
|
||||||
assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
|
assistant_message_tool_calls = self._extract_response_tool_calls(
|
||||||
|
function_calls if function_calls else []
|
||||||
|
)
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
# transform assistant message to prompt message
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=delta.delta.content if delta.delta.content else '',
|
content=delta.delta.content if delta.delta.content else "",
|
||||||
tool_calls=assistant_message_tool_calls
|
tool_calls=assistant_message_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
if delta.finish_reason is not None:
|
||||||
# temp_assistant_prompt_message is used to calculate usage
|
# temp_assistant_prompt_message is used to calculate usage
|
||||||
temp_assistant_prompt_message = AssistantPromptMessage(
|
temp_assistant_prompt_message = AssistantPromptMessage(
|
||||||
content=full_response,
|
content=full_response, tool_calls=assistant_message_tool_calls
|
||||||
tool_calls=assistant_message_tool_calls
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
|
prompt_tokens = self._num_tokens_from_messages(
|
||||||
completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
|
messages=prompt_messages, tools=tools
|
||||||
|
)
|
||||||
|
completion_tokens = self._num_tokens_from_messages(
|
||||||
|
messages=[temp_assistant_prompt_message], tools=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
|
||||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
model=model,
|
model=model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
@ -573,7 +674,7 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
index=delta.index,
|
index=delta.index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=delta.finish_reason,
|
finish_reason=delta.finish_reason,
|
||||||
usage=usage
|
usage=usage,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -589,9 +690,9 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
|
|
||||||
full_response += delta.delta.content
|
full_response += delta.delta.content
|
||||||
|
|
||||||
def _extract_response_tool_calls(self,
|
def _extract_response_tool_calls(
|
||||||
response_function_calls: list[FunctionCall]) \
|
self, response_function_calls: list[FunctionCall]
|
||||||
-> list[AssistantPromptMessage.ToolCall]:
|
) -> list[AssistantPromptMessage.ToolCall]:
|
||||||
"""
|
"""
|
||||||
Extract tool calls from response
|
Extract tool calls from response
|
||||||
|
|
||||||
@ -602,18 +703,15 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
if response_function_calls:
|
if response_function_calls:
|
||||||
for response_tool_call in response_function_calls:
|
for response_tool_call in response_function_calls:
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
name=response_tool_call.name,
|
name=response_tool_call.name, arguments=response_tool_call.arguments
|
||||||
arguments=response_tool_call.arguments
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
tool_call = AssistantPromptMessage.ToolCall(
|
||||||
id=0,
|
id=0, type="function", function=function
|
||||||
type='function',
|
|
||||||
function=function
|
|
||||||
)
|
)
|
||||||
tool_calls.append(tool_call)
|
tool_calls.append(tool_call)
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
@ -635,15 +733,9 @@ class LocalAILarguageModel(LargeLanguageModel):
|
|||||||
ConflictError,
|
ConflictError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnprocessableEntityError,
|
UnprocessableEntityError,
|
||||||
PermissionDeniedError
|
PermissionDeniedError,
|
||||||
],
|
],
|
||||||
InvokeRateLimitError: [
|
InvokeRateLimitError: [RateLimitError],
|
||||||
RateLimitError
|
InvokeAuthorizationError: [AuthenticationError],
|
||||||
],
|
InvokeBadRequestError: [ValueError],
|
||||||
InvokeAuthorizationError: [
|
|
||||||
AuthenticationError
|
|
||||||
],
|
|
||||||
InvokeBadRequestError: [
|
|
||||||
ValueError
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|||||||
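A minimal sketch of how the client kwargs built in _to_client_kwargs above resolve, assuming a hypothetical LocalAI server at http://localhost:8080 and that Timeout is httpx.Timeout as in the file's imports; the dummy api_key mirrors the placeholder used in the code:

from httpx import Timeout
from yarl import URL

server_url = "http://localhost:8080"  # hypothetical credential value
if not server_url.endswith("/"):
    server_url += "/"

client_kwargs = {
    "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    "api_key": "1",  # placeholder key, as in _to_client_kwargs
    "base_url": str(URL(server_url) / "v1"),  # -> "http://localhost:8080/v1"
}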
@@ -1,11 +1,12 @@
import logging

from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class LocalAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
@@ -6,8 +6,17 @@ from requests import post
from yarl import URL

from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
    PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    EmbeddingUsage,
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
@@ -16,17 +25,26 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
    TextEmbeddingModel,
)


class LocalAITextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Jina text embedding model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

@@ -37,39 +55,38 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
        :return: embeddings result
        """
        if len(texts) != 1:
            raise InvokeBadRequestError("Only one text is supported")

        server_url = credentials["server_url"]
        model_name = model
        if not server_url:
            raise CredentialsValidateFailedError("server_url is required")
        if not model_name:
            raise CredentialsValidateFailedError("model_name is required")

        url = server_url
        headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}

        data = {"model": model_name, "input": texts[0]}

        try:
            response = post(
                str(URL(url) / "embeddings"),
                headers=headers,
                data=dumps(data),
                timeout=10,
            )
        except Exception as e:
            raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            try:
                resp = response.json()
                code = resp["error"]["code"]
                msg = resp["error"]["message"]
                if code == 500:
                    raise InvokeServerUnavailableError(msg)

                if response.status_code == 401:
                    raise InvokeAuthorizationError(msg)
                elif response.status_code == 429:
@@ -79,23 +96,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
                else:
                    raise InvokeError(msg)
            except JSONDecodeError as e:
                raise InvokeServerUnavailableError(
                    f"Failed to convert response to json: {e} with text: {response.text}"
                )

        try:
            resp = response.json()
            embeddings = resp["data"]
            usage = resp["usage"]
        except Exception as e:
            raise InvokeServerUnavailableError(
                f"Failed to convert response to json: {e} with text: {response.text}"
            )

        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=usage["total_tokens"]
        )

        result = TextEmbeddingResult(
            model=model,
            embeddings=[[float(data) for data in x["embedding"]] for x in embeddings],
            usage=usage,
        )

        return result
@@ -114,8 +135,10 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
        # use GPT2Tokenizer to get num tokens
        num_tokens += self._get_num_tokens_by_gpt2(text)
        return num_tokens

    def _get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> AIModelEntity | None:
        """
        Get customizable model schema

@@ -130,10 +153,12 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
            features=[],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(
                    credentials.get("context_size", "512")
                ),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -145,33 +170,25 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=["ping"])
        except InvokeAuthorizationError:
            raise CredentialsValidateFailedError("Invalid credentials")
        except InvokeConnectionError as e:
            raise CredentialsValidateFailedError(f"Invalid credentials: {e}")

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [KeyError],
        }

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage

@@ -185,7 +202,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens,
        )

        # transform usage
@@ -196,7 +213,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage
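A minimal sketch of the embeddings request that _invoke above issues, assuming a hypothetical LocalAI-compatible server and embedding model name:

from json import dumps

from requests import post
from yarl import URL

server_url = "http://localhost:8080"  # hypothetical server_url credential
headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"}
data = {"model": "my-embedding-model", "input": "ping"}  # hypothetical model name

response = post(
    str(URL(server_url) / "embeddings"), headers=headers, data=dumps(data), timeout=10
)
# the body follows the OpenAI-style embeddings schema parsed by _invoke above
vector = [float(v) for v in response.json()["data"][0]["embedding"]]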
@@ -12,47 +12,61 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors impor
    InvalidAuthenticationError,
    RateLimitReachedError,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.types import (
    MinimaxMessage,
)


class MinimaxChatCompletion:
    """
    Minimax Chat Completion API
    """

    def generate(
        self,
        model: str,
        api_key: str,
        group_id: str,
        prompt_messages: list[MinimaxMessage],
        model_parameters: dict,
        tools: list[dict[str, Any]],
        stop: list[str] | None,
        stream: bool,
        user: str,
    ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
        """
        generate chat completion
        """
        if not api_key or not group_id:
            raise InvalidAPIKeyError("Invalid API key or group ID")

        url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"

        extra_kwargs = {}

        if (
            "max_tokens" in model_parameters
            and type(model_parameters["max_tokens"]) == int
        ):
            extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]

        if (
            "temperature" in model_parameters
            and type(model_parameters["temperature"]) == float
        ):
            extra_kwargs["temperature"] = model_parameters["temperature"]

        if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
            extra_kwargs["top_p"] = model_parameters["top_p"]

        prompt = "你是一个什么都懂的专家"

        role_meta = {"user_name": "我", "bot_name": "专家"}

        # check if there is a system message
        if len(prompt_messages) == 0:
            raise BadRequestError("At least one message is required")

        if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
            if prompt_messages[0].content:
                prompt = prompt_messages[0].content
@@ -60,40 +74,48 @@ class MinimaxChatCompletion:

        # check if there is a user message
        if len(prompt_messages) == 0:
            raise BadRequestError("At least one user message is required")

        messages = [
            {
                "sender_type": message.role,
                "text": message.content,
            }
            for message in prompt_messages
        ]

        headers = {
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        }

        body = {
            "model": model,
            "messages": messages,
            "prompt": prompt,
            "role_meta": role_meta,
            "stream": stream,
            **extra_kwargs,
        }

        try:
            response = post(
                url=url,
                data=dumps(body),
                headers=headers,
                stream=stream,
                timeout=(10, 300),
            )
        except Exception as e:
            raise InternalServerError(e)

        if response.status_code != 200:
            raise InternalServerError(response.text)

        if stream:
            return self._handle_stream_chat_generate_response(response)
        return self._handle_chat_generate_response(response)

    def _handle_error(self, code: int, msg: str):
        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
            raise InternalServerError(msg)
@@ -110,65 +132,64 @@ class MinimaxChatCompletion:

    def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
        """
        handle chat generate response
        """
        response = response.json()
        if "base_resp" in response and response["base_resp"]["status_code"] != 0:
            code = response["base_resp"]["status_code"]
            msg = response["base_resp"]["status_msg"]
            self._handle_error(code, msg)

        message = MinimaxMessage(
            content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value
        )
        message.usage = {
            "prompt_tokens": 0,
            "completion_tokens": response["usage"]["total_tokens"],
            "total_tokens": response["usage"]["total_tokens"],
        }
        message.stop_reason = response["choices"][0]["finish_reason"]
        return message

    def _handle_stream_chat_generate_response(
        self, response: Response
    ) -> Generator[MinimaxMessage, None, None]:
        """
        handle stream chat generate response
        """
        for line in response.iter_lines():
            if not line:
                continue
            line: str = line.decode("utf-8")
            if line.startswith("data: "):
                line = line[6:].strip()
            data = loads(line)

            if "base_resp" in data and data["base_resp"]["status_code"] != 0:
                code = data["base_resp"]["status_code"]
                msg = data["base_resp"]["status_msg"]
                self._handle_error(code, msg)

            if data["reply"]:
                total_tokens = data["usage"]["total_tokens"]
                message = MinimaxMessage(
                    role=MinimaxMessage.Role.ASSISTANT.value, content=""
                )
                message.usage = {
                    "prompt_tokens": 0,
                    "completion_tokens": total_tokens,
                    "total_tokens": total_tokens,
                }
                message.stop_reason = data["choices"][0]["finish_reason"]
                yield message
                return

            choices = data.get("choices", [])
            if len(choices) == 0:
                continue

            for choice in choices:
                message = choice["delta"]
                yield MinimaxMessage(
                    content=message, role=MinimaxMessage.Role.ASSISTANT.value
                )
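A minimal sketch of the request body that generate above posts to the Minimax chat completion endpoint, with a hypothetical group id, API key, model name, and sender_type value:

from json import dumps

from requests import post

group_id = "YOUR_GROUP_ID"  # hypothetical
api_key = "YOUR_API_KEY"  # hypothetical
url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}

body = {
    "model": "abab5.5-chat",  # hypothetical model name
    "messages": [{"sender_type": "USER", "text": "你好"}],  # sender_type assumed to be "USER"
    "prompt": "你是一个什么都懂的专家",
    "role_meta": {"user_name": "我", "bot_name": "专家"},
    "stream": False,
    "tokens_to_generate": 512,
}

response = post(url=url, data=dumps(body), headers=headers, timeout=(10, 300))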
@@ -12,98 +12,115 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors impor
    InvalidAuthenticationError,
    RateLimitReachedError,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.types import (
    MinimaxMessage,
)


class MinimaxChatCompletionPro:
    """
    Minimax Chat Completion Pro API, supports function calling
    however, we do not have enough time and energy to implement it, but the parameters are reserved
    """

    def generate(
        self,
        model: str,
        api_key: str,
        group_id: str,
        prompt_messages: list[MinimaxMessage],
        model_parameters: dict,
        tools: list[dict[str, Any]],
        stop: list[str] | None,
        stream: bool,
        user: str,
    ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
        """
        generate chat completion
        """
        if not api_key or not group_id:
            raise InvalidAPIKeyError("Invalid API key or group ID")

        url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"

        extra_kwargs = {}

        if (
            "max_tokens" in model_parameters
            and type(model_parameters["max_tokens"]) == int
        ):
            extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"]

        if (
            "temperature" in model_parameters
            and type(model_parameters["temperature"]) == float
        ):
            extra_kwargs["temperature"] = model_parameters["temperature"]

        if "top_p" in model_parameters and type(model_parameters["top_p"]) == float:
            extra_kwargs["top_p"] = model_parameters["top_p"]

        if (
            "plugin_web_search" in model_parameters
            and model_parameters["plugin_web_search"]
        ):
            extra_kwargs["plugins"] = ["plugin_web_search"]

        bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"}

        reply_constraints = {"sender_type": "BOT", "sender_name": "专家"}

        # check if there is a system message
        if len(prompt_messages) == 0:
            raise BadRequestError("At least one message is required")

        if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
            if prompt_messages[0].content:
                bot_setting["content"] = prompt_messages[0].content
            prompt_messages = prompt_messages[1:]

        # check if there is a user message
        if len(prompt_messages) == 0:
            raise BadRequestError("At least one user message is required")

        messages = [message.to_dict() for message in prompt_messages]

        headers = {
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        }

        body = {
            "model": model,
            "messages": messages,
            "bot_setting": [bot_setting],
            "reply_constraints": reply_constraints,
            "stream": stream,
            **extra_kwargs,
        }

        if tools:
            body["functions"] = tools
            body["function_call"] = {"type": "auto"}

        try:
            response = post(
                url=url,
                data=dumps(body),
                headers=headers,
                stream=stream,
                timeout=(10, 300),
            )
        except Exception as e:
            raise InternalServerError(e)

        if response.status_code != 200:
            raise InternalServerError(response.text)

        if stream:
            return self._handle_stream_chat_generate_response(response)
        return self._handle_chat_generate_response(response)

    def _handle_error(self, code: int, msg: str):
        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
            raise InternalServerError(msg)
@@ -120,92 +137,101 @@ class MinimaxChatCompletionPro:

    def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage:
        """
        handle chat generate response
        """
        response = response.json()
        if "base_resp" in response and response["base_resp"]["status_code"] != 0:
            code = response["base_resp"]["status_code"]
            msg = response["base_resp"]["status_msg"]
            self._handle_error(code, msg)

        message = MinimaxMessage(
            content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value
        )
        message.usage = {
            "prompt_tokens": 0,
|
||||||
'completion_tokens': response['usage']['total_tokens'],
|
"completion_tokens": response["usage"]["total_tokens"],
|
||||||
'total_tokens': response['usage']['total_tokens']
|
"total_tokens": response["usage"]["total_tokens"],
|
||||||
}
|
}
|
||||||
message.stop_reason = response['choices'][0]['finish_reason']
|
message.stop_reason = response["choices"][0]["finish_reason"]
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]:
|
def _handle_stream_chat_generate_response(
|
||||||
|
self, response: Response
|
||||||
|
) -> Generator[MinimaxMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
handle stream chat generate response
|
handle stream chat generate response
|
||||||
"""
|
"""
|
||||||
function_call_storage = None
|
function_call_storage = None
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
line: str = line.decode('utf-8')
|
line: str = line.decode("utf-8")
|
||||||
if line.startswith('data: '):
|
if line.startswith("data: "):
|
||||||
line = line[6:].strip()
|
line = line[6:].strip()
|
||||||
data = loads(line)
|
data = loads(line)
|
||||||
|
|
||||||
if 'base_resp' in data and data['base_resp']['status_code'] != 0:
|
if "base_resp" in data and data["base_resp"]["status_code"] != 0:
|
||||||
code = data['base_resp']['status_code']
|
code = data["base_resp"]["status_code"]
|
||||||
msg = data['base_resp']['status_msg']
|
msg = data["base_resp"]["status_msg"]
|
||||||
self._handle_error(code, msg)
|
self._handle_error(code, msg)
|
||||||
|
|
||||||
if data['reply'] or 'usage' in data and data['usage']:
|
if data["reply"] or "usage" in data and data["usage"]:
|
||||||
total_tokens = data['usage']['total_tokens']
|
total_tokens = data["usage"]["total_tokens"]
|
||||||
message = MinimaxMessage(
|
message = MinimaxMessage(
|
||||||
role=MinimaxMessage.Role.ASSISTANT.value,
|
role=MinimaxMessage.Role.ASSISTANT.value, content=""
|
||||||
content=''
|
|
||||||
)
|
)
|
||||||
message.usage = {
|
message.usage = {
|
||||||
'prompt_tokens': 0,
|
"prompt_tokens": 0,
|
||||||
'completion_tokens': total_tokens,
|
"completion_tokens": total_tokens,
|
||||||
'total_tokens': total_tokens
|
"total_tokens": total_tokens,
|
||||||
}
|
}
|
||||||
message.stop_reason = data['choices'][0]['finish_reason']
|
message.stop_reason = data["choices"][0]["finish_reason"]
|
||||||
|
|
||||||
if function_call_storage:
|
if function_call_storage:
|
||||||
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
function_call_message = MinimaxMessage(
|
||||||
|
content="", role=MinimaxMessage.Role.ASSISTANT.value
|
||||||
|
)
|
||||||
function_call_message.function_call = function_call_storage
|
function_call_message.function_call = function_call_storage
|
||||||
yield function_call_message
|
yield function_call_message
|
||||||
|
|
||||||
yield message
|
yield message
|
||||||
return
|
return
|
||||||
|
|
||||||
choices = data.get('choices', [])
|
choices = data.get("choices", [])
|
||||||
if len(choices) == 0:
|
if len(choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
message = choice['messages'][0]
|
message = choice["messages"][0]
|
||||||
|
|
||||||
if 'function_call' in message:
|
if "function_call" in message:
|
||||||
if not function_call_storage:
|
if not function_call_storage:
|
||||||
function_call_storage = message['function_call']
|
function_call_storage = message["function_call"]
|
||||||
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
|
if (
|
||||||
function_call_storage['arguments'] = ''
|
"arguments" not in function_call_storage
|
||||||
|
or not function_call_storage["arguments"]
|
||||||
|
):
|
||||||
|
function_call_storage["arguments"] = ""
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
function_call_storage['arguments'] += message['function_call']['arguments']
|
function_call_storage["arguments"] += message["function_call"][
|
||||||
|
"arguments"
|
||||||
|
]
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if function_call_storage:
|
if function_call_storage:
|
||||||
message['function_call'] = function_call_storage
|
message["function_call"] = function_call_storage
|
||||||
function_call_storage = None
|
function_call_storage = None
|
||||||
|
|
||||||
minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
|
||||||
|
|
||||||
if 'function_call' in message:
|
minimax_message = MinimaxMessage(
|
||||||
minimax_message.function_call = message['function_call']
|
content="", role=MinimaxMessage.Role.ASSISTANT.value
|
||||||
|
)
|
||||||
|
|
||||||
if 'text' in message:
|
if "function_call" in message:
|
||||||
minimax_message.content = message['text']
|
minimax_message.function_call = message["function_call"]
|
||||||
|
|
||||||
|
if "text" in message:
|
||||||
|
minimax_message.content = message["text"]
|
||||||
|
|
||||||
yield minimax_message
|
yield minimax_message
|
||||||
|
|||||||
@@ -1,17 +1,22 @@
class InvalidAuthenticationError(Exception):
    pass


class InvalidAPIKeyError(Exception):
    pass


class RateLimitReachedError(Exception):
    pass


class InsufficientAccountBalanceError(Exception):
    pass


class InternalServerError(Exception):
    pass


class BadRequestError(Exception):
    pass
@@ -1,6 +1,10 @@
from collections.abc import Generator

from model_providers.core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
@@ -17,10 +21,18 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion import (
    MinimaxChatCompletion,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.chat_completion_pro import (
    MinimaxChatCompletionPro,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.errors import (
    BadRequestError,
    InsufficientAccountBalanceError,
@@ -29,131 +41,202 @@ from model_providers.core.model_runtime.model_providers.minimax.llm.errors import (
    InvalidAuthenticationError,
    RateLimitReachedError,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.types import (
    MinimaxMessage,
)


class MinimaxLargeLanguageModel(LargeLanguageModel):
    model_apis = {
        "abab6-chat": MinimaxChatCompletionPro,
        "abab5.5s-chat": MinimaxChatCompletionPro,
        "abab5.5-chat": MinimaxChatCompletionPro,
        "abab5-chat": MinimaxChatCompletion,
    }

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        return self._generate(
            model,
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials for Minimax model
        """
        if model not in self.model_apis:
            raise CredentialsValidateFailedError(f"Invalid model: {model}")

        if not credentials.get("minimax_api_key"):
            raise CredentialsValidateFailedError("Invalid API key")

        if not credentials.get("minimax_group_id"):
            raise CredentialsValidateFailedError("Invalid group ID")

        # ping
        instance = MinimaxChatCompletionPro()
        try:
            instance.generate(
                model=model,
                api_key=credentials["minimax_api_key"],
                group_id=credentials["minimax_group_id"],
                prompt_messages=[MinimaxMessage(content="ping", role="USER")],
                model_parameters={},
                tools=[],
                stop=[],
                stream=False,
                user="",
            )
        except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e:
            raise CredentialsValidateFailedError(f"Invalid API key: {e}")

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool] | None = None,
    ) -> int:
        return self._num_tokens_from_messages(prompt_messages, tools)

    def _num_tokens_from_messages(
        self, messages: list[PromptMessage], tools: list[PromptMessageTool]
    ) -> int:
        """
        Calculate num tokens for minimax model

        Unlike ChatGLM, Minimax has a special prompt structure and we could not find a proper way
        to calculate the num tokens, so we use str() to convert the prompt to string

        Minimax does not provide their own tokenizer for the abab5.5 and abab5 models,
        therefore we use the gpt2 tokenizer instead
        """
        messages_dict = [
            self._convert_prompt_message_to_minimax_message(m).to_dict()
            for m in messages
        ]
        return self._get_num_tokens_by_gpt2(str(messages_dict))

    def _generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        """
        use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
        """
        client: MinimaxChatCompletionPro = self.model_apis[model]()

        if tools:
            tools = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                }
                for tool in tools
            ]

        response = client.generate(
            model=model,
            api_key=credentials["minimax_api_key"],
            group_id=credentials["minimax_group_id"],
            prompt_messages=[
                self._convert_prompt_message_to_minimax_message(message)
                for message in prompt_messages
            ],
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

        if stream:
            return self._handle_chat_generate_stream_response(
                model=model,
                prompt_messages=prompt_messages,
                credentials=credentials,
                response=response,
            )
        return self._handle_chat_generate_response(
            model=model,
            prompt_messages=prompt_messages,
            credentials=credentials,
            response=response,
        )

    def _convert_prompt_message_to_minimax_message(
        self, prompt_message: PromptMessage
    ) -> MinimaxMessage:
        """
        convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface
        """
        if isinstance(prompt_message, SystemPromptMessage):
            return MinimaxMessage(
                role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content
            )
        elif isinstance(prompt_message, UserPromptMessage):
            return MinimaxMessage(
                role=MinimaxMessage.Role.USER.value, content=prompt_message.content
            )
        elif isinstance(prompt_message, AssistantPromptMessage):
            if prompt_message.tool_calls:
                message = MinimaxMessage(
                    role=MinimaxMessage.Role.ASSISTANT.value, content=""
                )
                message.function_call = {
                    "name": prompt_message.tool_calls[0].function.name,
                    "arguments": prompt_message.tool_calls[0].function.arguments,
                }
                return message
            return MinimaxMessage(
                role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content
            )
        elif isinstance(prompt_message, ToolPromptMessage):
            return MinimaxMessage(
                role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content
            )
        else:
            raise NotImplementedError(
                f"Prompt message type {type(prompt_message)} is not supported"
            )

    def _handle_chat_generate_response(
        self,
        model: str,
        prompt_messages: list[PromptMessage],
        credentials: dict,
        response: MinimaxMessage,
    ) -> LLMResult:
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            prompt_tokens=response.usage["prompt_tokens"],
            completion_tokens=response.usage["completion_tokens"],
        )
        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
@@ -164,15 +247,20 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
            usage=usage,
        )

    def _handle_chat_generate_stream_response(
        self,
        model: str,
        prompt_messages: list[PromptMessage],
        credentials: dict,
        response: Generator[MinimaxMessage, None, None],
    ) -> Generator[LLMResultChunk, None, None]:
        for message in response:
            if message.usage:
                usage = self._calc_response_usage(
                    model=model,
                    credentials=credentials,
                    prompt_tokens=message.usage["prompt_tokens"],
                    completion_tokens=message.usage["completion_tokens"],
                )
                yield LLMResultChunk(
                    model=model,
@@ -180,15 +268,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=message.content, tool_calls=[]
                        ),
                        usage=usage,
                        finish_reason=message.stop_reason
                        if message.stop_reason
                        else None,
                    ),
                )
            elif message.function_call:
                if (
                    "name" not in message.function_call
                    or "arguments" not in message.function_call
                ):
                    continue

                yield LLMResultChunk(
@@ -197,15 +289,17 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content="",
                            tool_calls=[
                                AssistantPromptMessage.ToolCall(
                                    id="",
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=message.function_call["name"],
                                        arguments=message.function_call["arguments"],
                                    ),
                                )
                            ],
                        ),
                    ),
                )
@@ -216,10 +310,11 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=message.content, tool_calls=[]
                        ),
                        finish_reason=message.stop_reason
                        if message.stop_reason
                        else None,
                    ),
                )

@@ -234,22 +329,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [InternalServerError],
            InvokeRateLimitError: [RateLimitReachedError],
            InvokeAuthorizationError: [
                InvalidAuthenticationError,
                InsufficientAccountBalanceError,
                InvalidAPIKeyError,
            ],
            InvokeBadRequestError: [BadRequestError, KeyError],
        }
@@ -4,32 +4,32 @@ from typing import Any

class MinimaxMessage:
    class Role(Enum):
        USER = "USER"
        ASSISTANT = "BOT"
        SYSTEM = "SYSTEM"
        FUNCTION = "FUNCTION"

    role: str = Role.USER.value
    content: str
    usage: dict[str, int] = None
    stop_reason: str = ""
    function_call: dict[str, Any] = None

    def to_dict(self) -> dict[str, Any]:
        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
            return {
                "sender_type": "BOT",
                "sender_name": "专家",
                "text": "",
                "function_call": self.function_call,
            }

        return {
            "sender_type": self.role,
            "sender_name": "我" if self.role == "USER" else "专家",
            "text": self.content,
        }

    def __init__(self, content: str, role: str = "USER") -> None:
        self.content = content
        self.role = role
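# A minimal illustrative sketch (not part of the commit): what MinimaxMessage.to_dict()
# returns for a plain user message, assuming the class defined above; the message text
# is a made-up example.
msg = MinimaxMessage(content="你好", role="USER")
assert msg.to_dict() == {"sender_type": "USER", "sender_name": "我", "text": "你好"}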
@@ -1,11 +1,16 @@
import logging

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class MinimaxProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
@@ -20,11 +25,12 @@ class MinimaxProvider(ModelProvider):

            # Use `abab5.5-chat` model for validate,
            model_instance.validate_credentials(
                model="abab5.5-chat", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(
                f"{self.get_provider_schema().provider} credentials validate failed"
            )
            raise CredentialsValidateFailedError(f"{ex}")
@@ -5,7 +5,10 @@ from typing import Optional
from requests import post

from model_providers.core.model_runtime.entities.model_entities import PriceType
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    EmbeddingUsage,
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
@@ -14,8 +17,12 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
    TextEmbeddingModel,
)
from model_providers.core.model_runtime.model_providers.minimax.llm.errors import (
    BadRequestError,
    InsufficientAccountBalanceError,
@@ -30,11 +37,16 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Minimax text embedding model.
    """

    api_base: str = "https://api.minimax.chat/v1/embeddings"

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

@@ -44,55 +56,51 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
        :param user: unique user id
        :return: embeddings result
        """
        api_key = credentials["minimax_api_key"]
        group_id = credentials["minimax_group_id"]
        if model != "embo-01":
            raise ValueError("Invalid model name")
        if not api_key:
            raise CredentialsValidateFailedError("api_key is required")
        url = f"{self.api_base}?GroupId={group_id}"
        headers = {
            "Authorization": "Bearer " + api_key,
            "Content-Type": "application/json",
        }

        data = {"model": "embo-01", "texts": texts, "type": "db"}

        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            raise InvokeServerUnavailableError(response.text)

        try:
            resp = response.json()
            # check if there is an error
            if resp["base_resp"]["status_code"] != 0:
                code = resp["base_resp"]["status_code"]
                msg = resp["base_resp"]["status_msg"]
                self._handle_error(code, msg)

            embeddings = resp["vectors"]
            total_tokens = resp["total_tokens"]
        except InvalidAuthenticationError:
            raise InvalidAPIKeyError("Invalid api key")
        except KeyError as e:
            raise InternalServerError(
                f"Failed to convert response to json: {e} with text: {response.text}"
            )

        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=total_tokens
        )

        result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage)

        return result

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
@@ -119,9 +127,9 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=["ping"])
        except InvalidAPIKeyError:
            raise CredentialsValidateFailedError("Invalid api key")

    def _handle_error(self, code: int, msg: str):
        if code == 1000 or code == 1001:
@@ -148,26 +156,20 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [InternalServerError],
            InvokeRateLimitError: [RateLimitReachedError],
            InvokeAuthorizationError: [
                InvalidAuthenticationError,
                InsufficientAccountBalanceError,
                InvalidAPIKeyError,
            ],
            InvokeBadRequestError: [BadRequestError, KeyError],
        }

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage

@@ -181,7 +183,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens,
        )

        # transform usage
@@ -192,7 +194,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage
@@ -2,24 +2,43 @@ from collections.abc import Generator
from typing import Optional, Union

from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
    OAIAPICompatLargeLanguageModel,
)


class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)

        # mistral does not support user/stop arguments
        stop = []
        user = None

        return super()._invoke(
            model,
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
@@ -27,5 +46,5 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["mode"] = "chat"
        credentials["endpoint_url"] = "https://api.mistral.ai/v1"
@@ -1,14 +1,17 @@
import logging

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class MistralAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
@@ -20,11 +23,12 @@ class MistralAIProvider(ModelProvider):
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model="open-mistral-7b", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(
                f"{self.get_provider_schema().provider} credentials validate failed"
            )
            raise ex
@@ -6,11 +6,24 @@ from pydantic import BaseModel
from pydantic import BaseModel

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
    ProviderConfig,
    ProviderEntity,
    SimpleProviderEntity,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)
from model_providers.core.model_runtime.schema_validators.model_credential_schema_validator import (
    ModelCredentialSchemaValidator,
)
from model_providers.core.model_runtime.schema_validators.provider_credential_schema_validator import (
    ProviderCredentialSchemaValidator,
)
from model_providers.core.utils.position_helper import (
    get_position_map,
    sort_to_dict_by_position_map,
)

logger = logging.getLogger(__name__)

@@ -91,8 +104,9 @@ class ModelProviderFactory:

        return filtered_credentials

    def model_credentials_validate(
        self, provider: str, model_type: ModelType, model: str, credentials: dict
    ) -> dict:
        """
        Validate model credentials

@@ -123,11 +137,12 @@ class ModelProviderFactory:

        return filtered_credentials

    def get_models(
        self,
        provider: Optional[str] = None,
        model_type: Optional[ModelType] = None,
        provider_configs: Optional[list[ProviderConfig]] = None,
    ) -> list[SimpleProviderEntity]:
        """
        Get all models for given model type

@@ -142,7 +157,9 @@ class ModelProviderFactory:
        # convert provider_configs to dict
        provider_credentials_dict = {}
        for provider_config in provider_configs:
            provider_credentials_dict[
                provider_config.provider
            ] = provider_config.credentials

        # traverse all model_provider_extensions
        providers = []
@@ -192,7 +209,7 @@ class ModelProviderFactory:
        # get the provider extension
        model_provider_extension = model_provider_extensions.get(provider)
        if not model_provider_extension:
            raise Exception(f"Invalid provider: {provider}")

        # get the provider instance
        model_provider_instance = model_provider_extension.provider_instance
@@ -203,7 +220,6 @@ class ModelProviderFactory:
        if self.model_provider_extensions:
            return self.model_provider_extensions

        # get the path of current classes
        current_path = os.path.abspath(__file__)
        model_providers_path = os.path.dirname(current_path)
@@ -212,8 +228,8 @@ class ModelProviderFactory:
        model_provider_dir_paths = [
            os.path.join(model_providers_path, model_provider_dir)
            for model_provider_dir in os.listdir(model_providers_path)
            if not model_provider_dir.startswith("__")
            and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
        ]

        # get _position.yaml file path
@@ -227,37 +243,54 @@ class ModelProviderFactory:

            file_names = os.listdir(model_provider_dir_path)

            if (model_provider_name + ".py") not in file_names:
                logger.warning(
                    f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip."
                )
                continue

            # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
            py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
            spec = importlib.util.spec_from_file_location(
                f"model_providers.core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
                py_path,
            )
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)

            model_provider_class = None
            for name, obj in vars(mod).items():
                if (
                    isinstance(obj, type)
                    and issubclass(obj, ModelProvider)
                    and obj != ModelProvider
                ):
                    model_provider_class = obj
                    break

            if not model_provider_class:
                logger.warning(
                    f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip."
                )
                continue

            if f"{model_provider_name}.yaml" not in file_names:
                logger.warning(
                    f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip."
                )
                continue

            model_providers.append(
                ModelProviderExtension(
                    name=model_provider_name,
                    provider_instance=model_provider_class(),
                    position=position_map.get(model_provider_name),
                )
            )

        sorted_extensions = sort_to_dict_by_position_map(
            position_map, model_providers, lambda x: x.name
        )

        self.model_provider_extensions = sorted_extensions
|||||||
@ -2,19 +2,39 @@ from collections.abc import Generator
from typing import Optional, Union

from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from model_providers.core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
    OAIAPICompatLargeLanguageModel,
)


class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        user = user[:32] if user else None
        return super()._invoke(
            model,
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)

@ -22,5 +42,5 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["mode"] = "chat"
        credentials["endpoint_url"] = "https://api.moonshot.cn/v1"

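A minimal illustration of what the `_add_custom_parameters` hook above does to a caller-supplied credentials dict; the input key is a placeholder, only the two keys the hook sets come from this diff:

    credentials = {"api_key": "sk-placeholder"}  # illustrative caller-supplied credentials
    MoonshotLargeLanguageModel._add_custom_parameters(credentials)
    # The dict is mutated in place and now also carries the fixed values
    # {"mode": "chat", "endpoint_url": "https://api.moonshot.cn/v1"},
    # which the OpenAI-compatible base class then uses to build the request.
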
@ -1,14 +1,17 @@
import logging

from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class MoonshotProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

@ -20,11 +23,12 @@ class MoonshotProvider(ModelProvider):
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model="moonshot-v1-8k", credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(
                f"{self.get_provider_schema().provider} credentials validate failed"
            )
            raise ex

@ -8,7 +8,12 @@ from urllib.parse import urljoin

import requests

from model_providers.core.model_runtime.entities.llm_entities import (
    LLMMode,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from model_providers.core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,

@ -39,8 +44,12 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)

logger = logging.getLogger(__name__)

@ -50,11 +59,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
    Model class for Ollama large language model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

@ -75,11 +90,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            model_parameters=model_parameters,
            stop=stop,
            stream=stream,
            user=user,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

@ -100,10 +120,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        if isinstance(first_prompt_message.content, str):
            text = first_prompt_message.content
        else:
            text = ""
            for message_content in first_prompt_message.content:
                if message_content.type == PromptMessageContentType.TEXT:
                    message_content = cast(
                        TextPromptMessageContent, message_content
                    )
                    text = message_content.data
                    break
        return self._get_num_tokens_by_gpt2(text)

@ -121,19 +143,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={"num_predict": 5},
                stream=False,
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {ex.description}"
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {str(ex)}"
            )

    def _generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

@ -146,76 +177,89 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        headers = {"Content-Type": "application/json"}

        endpoint_url = credentials["base_url"]
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"

        # prepare the payload for a simple ping to the model
        data = {"model": model, "stream": stream}

        if "format" in model_parameters:
            data["format"] = model_parameters["format"]
            del model_parameters["format"]

        data["options"] = model_parameters or {}

        if stop:
            data["stop"] = "\n".join(stop)

        completion_type = LLMMode.value_of(credentials["mode"])

        if completion_type is LLMMode.CHAT:
            endpoint_url = urljoin(endpoint_url, "api/chat")
            data["messages"] = [
                self._convert_prompt_message_to_dict(m) for m in prompt_messages
            ]
        else:
            endpoint_url = urljoin(endpoint_url, "api/generate")
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message, UserPromptMessage):
                first_prompt_message = cast(UserPromptMessage, first_prompt_message)
                if isinstance(first_prompt_message.content, str):
                    data["prompt"] = first_prompt_message.content
                else:
                    text = ""
                    images = []
                    for message_content in first_prompt_message.content:
                        if message_content.type == PromptMessageContentType.TEXT:
                            message_content = cast(
                                TextPromptMessageContent, message_content
                            )
                            text = message_content.data
                        elif message_content.type == PromptMessageContentType.IMAGE:
                            message_content = cast(
                                ImagePromptMessageContent, message_content
                            )
                            image_data = re.sub(
                                r"^data:image\/[a-zA-Z]+;base64,",
                                "",
                                message_content.data,
                            )
                            images.append(image_data)

                    data["prompt"] = text
                    data["images"] = images

        # send a post request to validate the credentials
        response = requests.post(
            endpoint_url, headers=headers, json=data, timeout=(10, 60), stream=stream
        )

        response.encoding = "utf-8"
        if response.status_code != 200:
            raise InvokeError(
                f"API request failed with status code {response.status_code}: {response.text}"
            )

        if stream:
            return self._handle_generate_stream_response(
                model, credentials, completion_type, response, prompt_messages
            )

        return self._handle_generate_response(
            model, credentials, completion_type, response, prompt_messages
        )

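For orientation, a sketch of the JSON body that `_generate` above ends up posting to Ollama; the values are illustrative, only the keys assigned in the code are assumed:

    # chat mode -> POST {base_url}api/chat
    chat_payload = {
        "model": "llama2",  # illustrative model name
        "stream": True,
        "options": {"temperature": 0.8},  # whatever remains in model_parameters
        "messages": [{"role": "user", "content": "Hello"}],
    }

    # completion mode -> POST {base_url}api/generate
    generate_payload = {
        "model": "llama2",
        "stream": True,
        "options": {},
        "prompt": "Hello",
        "images": [],  # base64 payloads with the data-URI prefix stripped
    }
    # "format": "json" and a single joined "stop" string are added on top when the caller supplies them.
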
    def _handle_generate_response(
        self,
        model: str,
        credentials: dict,
        completion_type: LLMMode,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Handle llm completion response

@ -229,14 +273,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        response_json = response.json()

        if completion_type is LLMMode.CHAT:
            message = response_json.get("message", {})
            response_content = message.get("content", "")
        else:
            response_content = response_json["response"]

        assistant_message = AssistantPromptMessage(content=response_content)

        if "prompt_eval_count" in response_json and "eval_count" in response_json:
            # transform usage
            prompt_tokens = response_json["prompt_eval_count"]
            completion_tokens = response_json["eval_count"]

@ -246,7 +290,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

        # transform usage
        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )

        # transform response
        result = LLMResult(

@ -258,8 +304,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        return result

    def _handle_generate_stream_response(
        self,
        model: str,
        credentials: dict,
        completion_type: LLMMode,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
    ) -> Generator:
        """
        Handle llm completion stream response

@ -270,17 +322,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        full_text = ""
        chunk_index = 0

        def create_final_llm_result_chunk(
            index: int, message: AssistantPromptMessage, finish_reason: str
        ) -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(full_text)

            # transform usage
            usage = self._calc_response_usage(
                model, credentials, prompt_tokens, completion_tokens
            )

            return LLMResultChunk(
                model=model,

@ -289,11 +344,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage,
                ),
            )

        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
            if not chunk:
                continue

@ -304,7 +359,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                yield create_final_llm_result_chunk(
                    index=chunk_index,
                    message=AssistantPromptMessage(content=""),
                    finish_reason="Non-JSON encountered.",
                )

                chunk_index += 1

@ -314,55 +369,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                if not chunk_json:
                    continue

                if "message" not in chunk_json:
                    text = ""
                else:
                    text = chunk_json.get("message").get("content", "")
            else:
                if not chunk_json:
                    continue

                # transform assistant message to prompt message
                text = chunk_json["response"]

            assistant_prompt_message = AssistantPromptMessage(content=text)

            full_text += text

            if chunk_json["done"]:
                # calculate num tokens
                if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
                    # transform usage
                    prompt_tokens = chunk_json["prompt_eval_count"]
                    completion_tokens = chunk_json["eval_count"]
                else:
                    # calculate num tokens
                    prompt_tokens = self._get_num_tokens_by_gpt2(
                        prompt_messages[0].content
                    )
                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)

                # transform usage
                usage = self._calc_response_usage(
                    model, credentials, prompt_tokens, completion_tokens
                )

                yield LLMResultChunk(
                    model=chunk_json["model"],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                        finish_reason="stop",
                        usage=usage,
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=chunk_json["model"],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    ),
                )

            chunk_index += 1

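As a reference for the streaming branch above, roughly what the NDJSON lines read from Ollama look like; the values are made up, only the keys the code actually reads are assumed:

    # intermediate chunk (chat mode); in completion mode the delta text sits under "response" instead
    intermediate_chunk = {"model": "llama2", "message": {"content": "Hel"}, "done": False}

    # final chunk; when the token counts are missing the code falls back to a GPT-2 estimate
    final_chunk = {"model": "llama2", "done": True, "prompt_eval_count": 12, "eval_count": 48}
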
@ -376,15 +433,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        if isinstance(message.content, str):
            message_dict = {"role": "user", "content": message.content}
        else:
            text = ""
            images = []
            for message_content in message.content:
                if message_content.type == PromptMessageContentType.TEXT:
                    message_content = cast(
                        TextPromptMessageContent, message_content
                    )
                    text = message_content.data
                elif message_content.type == PromptMessageContentType.IMAGE:
                    message_content = cast(
                        ImagePromptMessageContent, message_content
                    )
                    image_data = re.sub(
                        r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
                    )
                    images.append(image_data)

            message_dict = {"role": "user", "content": text, "images": images}

@ -414,7 +477,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        return num_tokens

    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> AIModelEntity:
        """
        Get customizable model schema.

@ -425,20 +490,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        """
        extras = {}

        if "vision_support" in credentials and credentials["vision_support"] == "true":
            extras["features"] = [ModelFeature.VISION]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(zh_Hans=model, en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: credentials.get("mode"),
                ModelPropertyKey.CONTEXT_SIZE: int(
                    credentials.get("context_size", 4096)
                ),
            },
            parameter_rules=[
                ParameterRule(

@ -446,161 +510,191 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                    use_template=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="The temperature of the model. "
                        "Increasing the temperature will make the model answer "
                        "more creatively. (Default: 0.8)"
                    ),
                    default=0.8,
                    min=0,
                    max=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    use_template=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                        "more diverse text, while a lower value (e.g., 0.5) will generate more "
                        "focused and conservative text. (Default: 0.9)"
                    ),
                    default=0.9,
                    min=0,
                    max=1,
                ),
                ParameterRule(
                    name="top_k",
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Reduces the probability of generating nonsense. "
                        "A higher value (e.g. 100) will give more diverse answers, "
                        "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
                    ),
                    default=40,
                    min=1,
                    max=100,
                ),
                ParameterRule(
                    name="repeat_penalty",
                    label=I18nObject(en_US="Repeat Penalty"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="Sets how strongly to penalize repetitions. "
                        "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                        "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
                    ),
                    default=1.1,
                    min=-2,
                    max=2,
                ),
                ParameterRule(
                    name="num_predict",
                    use_template="max_tokens",
                    label=I18nObject(en_US="Num Predict"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Maximum number of tokens to predict when generating text. "
                        "(Default: 128, -1 = infinite generation, -2 = fill context)"
                    ),
                    default=128,
                    min=-2,
                    max=int(credentials.get("max_tokens", 4096)),
                ),
                ParameterRule(
                    name="mirostat",
                    label=I18nObject(en_US="Mirostat sampling"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Enable Mirostat sampling for controlling perplexity. "
                        "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
                    ),
                    default=0,
                    min=0,
                    max=2,
                ),
                ParameterRule(
                    name="mirostat_eta",
                    label=I18nObject(en_US="Mirostat Eta"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="Influences how quickly the algorithm responds to feedback from "
                        "the generated text. A lower learning rate will result in slower adjustments, "
                        "while a higher learning rate will make the algorithm more responsive. "
                        "(Default: 0.1)"
                    ),
                    default=0.1,
                    precision=1,
                ),
                ParameterRule(
                    name="mirostat_tau",
                    label=I18nObject(en_US="Mirostat Tau"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="Controls the balance between coherence and diversity of the output. "
                        "A lower value will result in more focused and coherent text. (Default: 5.0)"
                    ),
                    default=5.0,
                    precision=1,
                ),
                ParameterRule(
                    name="num_ctx",
                    label=I18nObject(en_US="Size of context window"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Sets the size of the context window used to generate the next token. "
                        "(Default: 2048)"
                    ),
                    default=2048,
                    min=1,
                ),
                ParameterRule(
                    name="num_gpu",
                    label=I18nObject(en_US="Num GPU"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="The number of layers to send to the GPU(s). "
                        "On macOS it defaults to 1 to enable metal support, 0 to disable."
                    ),
                    default=1,
                    min=0,
                    max=1,
                ),
                ParameterRule(
                    name="num_thread",
                    label=I18nObject(en_US="Num Thread"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Sets the number of threads to use during computation. "
                        "By default, Ollama will detect this for optimal performance. "
                        "It is recommended to set this value to the number of physical CPU cores "
                        "your system has (as opposed to the logical number of cores)."
                    ),
                    min=1,
                ),
                ParameterRule(
                    name="repeat_last_n",
                    label=I18nObject(en_US="Repeat last N"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Sets how far back for the model to look back to prevent repetition. "
                        "(Default: 64, 0 = disabled, -1 = num_ctx)"
                    ),
                    default=64,
                    min=-1,
                ),
                ParameterRule(
                    name="tfs_z",
                    label=I18nObject(en_US="TFS Z"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(
                        en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                        "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                        "while a value of 1.0 disables this setting. (default: 1)"
                    ),
                    default=1,
                    precision=1,
                ),
                ParameterRule(
                    name="seed",
                    label=I18nObject(en_US="Seed"),
                    type=ParameterType.INT,
                    help=I18nObject(
                        en_US="Sets the random number seed to use for generation. Setting this to "
                        "a specific number will make the model generate the same text for "
                        "the same prompt. (Default: 0)"
                    ),
                    default=0,
                ),
                ParameterRule(
                    name="format",
                    label=I18nObject(en_US="Format"),
                    type=ParameterType.STRING,
                    help=I18nObject(
                        en_US="the format to return a response in."
                        " Currently the only accepted value is json."
                    ),
                    options=["json"],
                ),
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get("input_price", 0)),
                output=Decimal(credentials.get("output_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
            **extras,
        )

        return entity

@ -628,10 +722,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError,  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout,  # Timeout
            ],
        }

@ -1,12 +1,13 @@
import logging

from model_providers.core.model_runtime.model_providers.__base.model_provider import (
    ModelProvider,
)

logger = logging.getLogger(__name__)


class OpenAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

@ -17,7 +17,10 @@ from model_providers.core.model_runtime.entities.model_entities import (
    PriceConfig,
    PriceType,
)
from model_providers.core.model_runtime.entities.text_embedding_entities import (
    EmbeddingUsage,
    TextEmbeddingResult,
)
from model_providers.core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,

@ -26,8 +29,12 @@ from model_providers.core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
    TextEmbeddingModel,
)

logger = logging.getLogger(__name__)

@ -37,9 +44,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
    Model class for an Ollama text embedding model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

@ -51,15 +62,13 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
        """

        # Prepare headers and payload for the request
        headers = {"Content-Type": "application/json"}

        endpoint_url = credentials.get("base_url")
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"

        endpoint_url = urljoin(endpoint_url, "api/embeddings")

        # get model properties
        context_size = self._get_context_size(model, credentials)

@ -74,7 +83,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            if num_tokens >= context_size:
                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

@ -83,8 +92,8 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
        for text in inputs:
            # Prepare the payload for the request
            payload = {
                "prompt": text,
                "model": model,
            }

            # Make the request to the OpenAI API

@ -92,14 +101,14 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300),
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Extract embeddings and used tokens from the response
            embeddings = response_data["embedding"]
            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

            used_tokens += embedding_used_tokens

@ -107,15 +116,11 @@ class OllamaEmbeddingModel(TextEmbeddingModel):

        # calc usage
        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings, usage=usage, model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:

@ -138,19 +143,21 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=["ping"])
        except InvokeError as ex:
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {ex.description}"
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {str(ex)}"
            )

    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,

@ -158,20 +165,22 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get("input_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
        )

        return entity

    def _calc_response_usage(
        self, model: str, credentials: dict, tokens: int
    ) -> EmbeddingUsage:
        """
        Calculate response usage

@ -185,7 +194,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens,
        )

        # transform usage

@ -196,7 +205,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at,
        )

        return usage

@ -224,10 +233,10 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError,  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout,  # Timeout
            ],
        }

@ -20,17 +20,17 @@ class _CommonOpenAI:
        :return:
        """
        credentials_kwargs = {
            "api_key": credentials["openai_api_key"],
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "max_retries": 1,
        }

        if "openai_api_base" in credentials and credentials["openai_api_base"]:
            credentials["openai_api_base"] = credentials["openai_api_base"].rstrip("/")
            credentials_kwargs["base_url"] = credentials["openai_api_base"] + "/v1"

        if "openai_organization" in credentials:
            credentials_kwargs["organization"] = credentials["openai_organization"]

        return credentials_kwargs

@ -45,24 +45,17 @@ class _CommonOpenAI:
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
            InvokeServerUnavailableError: [openai.InternalServerError],
            InvokeRateLimitError: [openai.RateLimitError],
            InvokeAuthorizationError: [
                openai.AuthenticationError,
                openai.PermissionDeniedError,
            ],
            InvokeBadRequestError: [
                openai.BadRequestError,
                openai.NotFoundError,
                openai.UnprocessableEntityError,
                openai.APIError,
            ],
        }

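A small sketch of the credential translation performed by `to_credential_kwargs` above; the concrete values are placeholders:

    credentials = {
        "openai_api_key": "sk-placeholder",
        "openai_api_base": "https://proxy.example.com/",  # optional; trailing slash is stripped
        "openai_organization": "org-placeholder",  # optional
    }
    # to_credential_kwargs(credentials) then returns roughly:
    # {
    #     "api_key": "sk-placeholder",
    #     "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
    #     "max_retries": 1,
    #     "base_url": "https://proxy.example.com/v1",
    #     "organization": "org-placeholder",
    # }
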
File diff suppressed because it is too large
@ -4,9 +4,15 @@ from openai import OpenAI
from openai.types import ModerationCreateResponse

from model_providers.core.model_runtime.entities.model_entities import ModelPropertyKey
from model_providers.core.model_runtime.errors.validate import (
    CredentialsValidateFailedError,
)
from model_providers.core.model_runtime.model_providers.__base.moderation_model import (
    ModerationModel,
)
from model_providers.core.model_runtime.model_providers.openai._common import (
    _CommonOpenAI,
)


class OpenAIModerationModel(_CommonOpenAI, ModerationModel):

@ -14,9 +20,9 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
    Model class for OpenAI text moderation model.
    """

    def _invoke(
        self, model: str, credentials: dict, text: str, user: Optional[str] = None
    ) -> bool:
        """
        Invoke moderation model

@ -34,13 +40,18 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):

        # chars per chunk
        length = self._get_max_characters_per_chunk(model, credentials)
        text_chunks = [text[i : i + length] for i in range(0, len(text), length)]

        max_text_chunks = self._get_max_chunks(model, credentials)
        chunks = [
            text_chunks[i : i + max_text_chunks]
            for i in range(0, len(text_chunks), max_text_chunks)
        ]

        for text_chunk in chunks:
            moderation_result = self._moderation_invoke(
                model=model, client=client, texts=text_chunk
            )

            for result in moderation_result.results:
                if result.flagged is True:

@ -65,12 +76,14 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
            self._moderation_invoke(
                model=model,
                client=client,
                texts=["ping"],
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _moderation_invoke(
        self, model: str, client: OpenAI, texts: list[str]
    ) -> ModerationCreateResponse:
        """
        Invoke moderation model

@ -94,8 +107,14 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if (
            model_schema
            and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK
            in model_schema.model_properties
        ):
            return model_schema.model_properties[
                ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK
            ]

        return 2000

@ -109,7 +128,10 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if (
            model_schema
            and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
        ):
            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]

        return 1

@@ -1,14 +1,17 @@
 import logging
 
 from model_providers.core.model_runtime.entities.model_entities import ModelType
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.model_provider import (
+    ModelProvider,
+)
 
 logger = logging.getLogger(__name__)
 
 
 class OpenAIProvider(ModelProvider):
 
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
         Validate provider credentials
@@ -22,11 +25,12 @@ class OpenAIProvider(ModelProvider):
             # Use `gpt-3.5-turbo` model for validate,
             # no matter what model you pass in, text completion model or chat model
             model_instance.validate_credentials(
-                model='gpt-3.5-turbo',
-                credentials=credentials
+                model="gpt-3.5-turbo", credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
             raise ex
         except Exception as ex:
-            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            logger.exception(
+                f"{self.get_provider_schema().provider} credentials validate failed"
+            )
             raise ex

@@ -2,9 +2,15 @@ from typing import IO, Optional
 
 from openai import OpenAI
 
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.speech2text_model import (
+    Speech2TextModel,
+)
+from model_providers.core.model_runtime.model_providers.openai._common import (
+    _CommonOpenAI,
+)
 
 
 class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel):
@@ -12,9 +18,9 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel):
     Model class for OpenAI Speech to text model.
     """
 
-    def _invoke(self, model: str, credentials: dict,
-                file: IO[bytes], user: Optional[str] = None) \
-            -> str:
+    def _invoke(
+        self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
+    ) -> str:
         """
         Invoke speech2text model
 
@@ -37,12 +43,14 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel):
         try:
             audio_file_path = self._get_demo_file_path()
 
-            with open(audio_file_path, 'rb') as audio_file:
+            with open(audio_file_path, "rb") as audio_file:
                 self._speech2text_invoke(model, credentials, audio_file)
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
+    def _speech2text_invoke(
+        self, model: str, credentials: dict, file: IO[bytes]
+    ) -> str:
         """
         Invoke speech2text model
 

@@ -7,10 +7,19 @@ import tiktoken
 from openai import OpenAI
 
 from model_providers.core.model_runtime.entities.model_entities import PriceType
-from model_providers.core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI
+from model_providers.core.model_runtime.entities.text_embedding_entities import (
+    EmbeddingUsage,
+    TextEmbeddingResult,
+)
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import (
+    TextEmbeddingModel,
+)
+from model_providers.core.model_runtime.model_providers.openai._common import (
+    _CommonOpenAI,
+)
 
 
 class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
@@ -18,9 +27,13 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
     Model class for OpenAI text embedding model.
     """
 
-    def _invoke(self, model: str, credentials: dict,
-                texts: list[str], user: Optional[str] = None) \
-            -> TextEmbeddingResult:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+    ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
 
@@ -37,9 +50,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
 
         extra_model_kwargs = {}
         if user:
-            extra_model_kwargs['user'] = user
+            extra_model_kwargs["user"] = user
 
-        extra_model_kwargs['encoding_format'] = 'base64'
+        extra_model_kwargs["encoding_format"] = "base64"
 
         # get model properties
         context_size = self._get_context_size(model, credentials)
@@ -56,11 +69,9 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
         enc = tiktoken.get_encoding("cl100k_base")
 
         for i, text in enumerate(texts):
-            token = enc.encode(
-                text
-            )
+            token = enc.encode(text)
             for j in range(0, len(token), context_size):
-                tokens += [token[j: j + context_size]]
+                tokens += [token[j : j + context_size]]
                 indices += [i]
 
         batched_embeddings = []
@@ -71,8 +82,8 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
             embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                 model=model,
                 client=client,
-                texts=tokens[i: i + max_chunks],
-                extra_model_kwargs=extra_model_kwargs
+                texts=tokens[i : i + max_chunks],
+                extra_model_kwargs=extra_model_kwargs,
             )
 
             used_tokens += embedding_used_tokens
@@ -91,7 +102,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
                 model=model,
                 client=client,
                 texts="",
-                extra_model_kwargs=extra_model_kwargs
+                extra_model_kwargs=extra_model_kwargs,
             )
 
             used_tokens += embedding_used_tokens
@@ -102,16 +113,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
 
         # calc usage
         usage = self._calc_response_usage(
-            model=model,
-            credentials=credentials,
-            tokens=used_tokens
+            model=model, credentials=credentials, tokens=used_tokens
         )
 
-        return TextEmbeddingResult(
-            embeddings=embeddings,
-            usage=usage,
-            model=model
-        )
+        return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
 
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
@@ -153,16 +158,18 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
 
             # call embedding model
             self._embedding_invoke(
-                model=model,
-                client=client,
-                texts=['ping'],
-                extra_model_kwargs={}
+                model=model, client=client, texts=["ping"], extra_model_kwargs={}
             )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str],
-                          extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
+    def _embedding_invoke(
+        self,
+        model: str,
+        client: OpenAI,
+        texts: Union[list[str], str],
+        extra_model_kwargs: dict,
+    ) -> tuple[list[list[float]], int]:
         """
         Invoke embedding model
 
@@ -179,14 +186,26 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
             **extra_model_kwargs,
         )
 
-        if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64':
+        if (
+            "encoding_format" in extra_model_kwargs
+            and extra_model_kwargs["encoding_format"] == "base64"
+        ):
             # decode base64 embedding
-            return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data],
-                    response.usage.total_tokens)
+            return (
+                [
+                    list(
+                        np.frombuffer(base64.b64decode(data.embedding), dtype="float32")
+                    )
+                    for data in response.data
+                ],
+                response.usage.total_tokens,
+            )
 
         return [data.embedding for data in response.data], response.usage.total_tokens
 
-    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+    def _calc_response_usage(
+        self, model: str, credentials: dict, tokens: int
+    ) -> EmbeddingUsage:
         """
         Calculate response usage
 
@@ -200,7 +219,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
             model=model,
             credentials=credentials,
             price_type=PriceType.INPUT,
-            tokens=tokens
+            tokens=tokens,
         )
 
         # transform usage
@@ -211,7 +230,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
             price_unit=input_price_info.unit,
             total_price=input_price_info.total_amount,
             currency=input_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )
 
         return usage

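The embedding code above requests vectors with encoding_format="base64" and restores them with numpy. A minimal, illustrative sketch of that decoding step; the payload built here is a made-up example, not actual API output:

import base64

import numpy as np


def decode_base64_embedding(payload: str) -> list[float]:
    # The API returns the raw float32 buffer encoded as base64;
    # np.frombuffer restores the vector without parsing JSON floats.
    raw = base64.b64decode(payload)
    return np.frombuffer(raw, dtype="float32").tolist()


# Round-trip check with a hypothetical 3-dimensional vector.
vector = np.asarray([0.1, 0.2, 0.3], dtype="float32")
encoded = base64.b64encode(vector.tobytes()).decode()
print(decode_base64_embedding(encoded))  # approximately [0.1, 0.2, 0.3]
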
@@ -3,13 +3,18 @@ from functools import reduce
 from io import BytesIO
 from typing import Optional
 
+from fastapi.responses import StreamingResponse
 from openai import OpenAI
 from pydub import AudioSegment
-from fastapi.responses import StreamingResponse
 from model_providers.core.model_runtime.errors.invoke import InvokeBadRequestError
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
 from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
-from model_providers.core.model_runtime.model_providers.openai._common import _CommonOpenAI
+from model_providers.core.model_runtime.model_providers.openai._common import (
+    _CommonOpenAI,
+)
 from model_providers.extensions.ext_storage import storage
 
 
@@ -18,8 +23,16 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
     Model class for OpenAI Speech to text model.
     """
 
-    def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+    def _invoke(
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        streaming: bool,
+        user: Optional[str] = None,
+    ) -> any:
         """
         _invoke text2speech model
 
@@ -33,18 +46,33 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         :return: text translated to audio file
         """
         audio_type = self._get_model_audio_type(model, credentials)
-        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
+        if not voice or voice not in [
+            d["value"]
+            for d in self.get_tts_model_voices(model=model, credentials=credentials)
+        ]:
             voice = self._get_model_default_voice(model, credentials)
         if streaming:
-            return StreamingResponse(self._tts_invoke_streaming(model=model,
-                                                                credentials=credentials,
-                                                                content_text=content_text,
-                                                                tenant_id=tenant_id,
-                                                                voice=voice), media_type='text/event-stream')
+            return StreamingResponse(
+                self._tts_invoke_streaming(
+                    model=model,
+                    credentials=credentials,
+                    content_text=content_text,
+                    tenant_id=tenant_id,
+                    voice=voice,
+                ),
+                media_type="text/event-stream",
+            )
         else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
+            return self._tts_invoke(
+                model=model,
+                credentials=credentials,
+                content_text=content_text,
+                voice=voice,
+            )
 
-    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
+    def validate_credentials(
+        self, model: str, credentials: dict, user: Optional[str] = None
+    ) -> None:
         """
         validate credentials text2speech model
 
@@ -57,13 +85,15 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
             self._tts_invoke(
                 model=model,
                 credentials=credentials,
-                content_text='Hello Dify!',
+                content_text="Hello Dify!",
                 voice=self._get_model_default_voice(model, credentials),
             )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> StreamingResponse:
+    def _tts_invoke(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> StreamingResponse:
         """
         _tts_invoke text2speech model
 
@@ -77,13 +107,25 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         word_limit = self._get_model_word_limit(model, credentials)
         max_workers = self._get_model_workers_limit(model, credentials)
         try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(
+                self._split_text_into_sentences(text=content_text, limit=word_limit)
+            )
             audio_bytes_list = list()
 
             # Create a thread pool and map the function to the list of sentences
-            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-                futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
-                                           credentials=credentials) for sentence in sentences]
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=max_workers
+            ) as executor:
+                futures = [
+                    executor.submit(
+                        self._process_sentence,
+                        sentence=sentence,
+                        model=model,
+                        voice=voice,
+                        credentials=credentials,
+                    )
+                    for sentence in sentences
+                ]
                 for future in futures:
                     try:
                         if future.result():
@@ -92,8 +134,11 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                         raise InvokeBadRequestError(str(ex))
 
             if len(audio_bytes_list) > 0:
-                audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
-                                  audio_bytes_list if audio_bytes]
+                audio_segments = [
+                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
+                    for audio_bytes in audio_bytes_list
+                    if audio_bytes
+                ]
                 combined_segment = reduce(lambda x, y: x + y, audio_segments)
                 buffer: BytesIO = BytesIO()
                 combined_segment.export(buffer, format=audio_type)
@@ -103,8 +148,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
             raise InvokeBadRequestError(str(ex))
 
     # Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
-                              voice: str) -> any:
+    def _tts_invoke_streaming(
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+    ) -> any:
         """
         _tts_invoke_streaming text2speech model
 
@@ -117,24 +168,29 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         """
         # transform credentials to kwargs for model instance
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+        if not voice or voice not in self.get_tts_model_voices(
+            model=model, credentials=credentials
+        ):
             voice = self._get_model_default_voice(model, credentials)
         word_limit = self._get_model_word_limit(model, credentials)
         audio_type = self._get_model_audio_type(model, credentials)
         tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
+        file_path = f"generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}"
         try:
             client = OpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(
+                self._split_text_into_sentences(text=content_text, limit=word_limit)
+            )
             for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
+                response = client.audio.speech.create(
+                    model=model, voice=voice, input=sentence.strip()
+                )
                 # response.stream_to_file(file_path)
                 storage.save(file_path, response.read())
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
 
-    def _process_sentence(self, sentence: str, model: str,
-                          voice, credentials: dict):
+    def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
         """
         _tts_invoke openai text2speech model api
 
@@ -147,6 +203,8 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         # transform credentials to kwargs for model instance
         credentials_kwargs = self._to_credential_kwargs(credentials)
         client = OpenAI(**credentials_kwargs)
-        response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
+        response = client.audio.speech.create(
+            model=model, voice=voice, input=sentence.strip()
+        )
         if isinstance(response.read(), bytes):
             return response.read()

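The _tts_invoke path above splits the text into sentences, synthesizes each one in a thread pool, and concatenates the pieces with pydub. Below is a simplified sketch of the same fan-out/fan-in pattern; the synthesize stub stands in for the real client.audio.speech.create call, and plain byte concatenation stands in for AudioSegment merging:

from concurrent.futures import ThreadPoolExecutor


def synthesize(sentence: str) -> bytes:
    # Stand-in for the real text-to-speech request.
    return sentence.encode("utf-8")


sentences = ["First sentence.", "Second sentence.", "Third sentence."]

# executor.map preserves input order, so the audio chunks can be joined
# in the same order as the original sentences.
with ThreadPoolExecutor(max_workers=3) as executor:
    audio_chunks = [chunk for chunk in executor.map(synthesize, sentences) if chunk]

combined = b"".join(audio_chunks)
print(len(combined))
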
@@ -1,4 +1,3 @@
-
 import requests
 
 from model_providers.core.model_runtime.errors.invoke import (
@@ -35,10 +34,10 @@ class _CommonOAI_API_Compat:
             ],
             InvokeServerUnavailableError: [
                 requests.exceptions.ConnectionError,  # Engine Overloaded
-                requests.exceptions.HTTPError  # Server Error
+                requests.exceptions.HTTPError,  # Server Error
             ],
             InvokeConnectionError: [
                 requests.exceptions.ConnectTimeout,  # Timeout
-                requests.exceptions.ReadTimeout  # Timeout
-            ]
+                requests.exceptions.ReadTimeout,  # Timeout
+            ],
         }

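The mapping above translates raw requests exceptions into the runtime's unified invoke errors. A self-contained sketch of how such a mapping can be applied; the two exception classes defined here are stand-ins for the real classes in model_providers.core.model_runtime.errors.invoke:

import requests


class InvokeConnectionError(Exception):
    pass


class InvokeServerUnavailableError(Exception):
    pass


# Same shape as the provider mapping: unified error type -> provider exceptions.
ERROR_MAPPING = {
    InvokeServerUnavailableError: [
        requests.exceptions.ConnectionError,
        requests.exceptions.HTTPError,
    ],
    InvokeConnectionError: [
        requests.exceptions.ConnectTimeout,
        requests.exceptions.ReadTimeout,
    ],
}


def translate_error(error: Exception) -> Exception:
    # Walk the mapping and re-wrap the provider exception in the unified type.
    for invoke_error, provider_errors in ERROR_MAPPING.items():
        if isinstance(error, tuple(provider_errors)):
            return invoke_error(str(error))
    return error


print(type(translate_error(requests.exceptions.ReadTimeout("timed out"))).__name__)
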
@@ -8,7 +8,12 @@ from urllib.parse import urljoin
 import requests
 
 from model_providers.core.model_runtime.entities.common_entities import I18nObject
-from model_providers.core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from model_providers.core.model_runtime.entities.llm_entities import (
+    LLMMode,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+)
 from model_providers.core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -33,9 +38,15 @@ from model_providers.core.model_runtime.entities.model_entities import (
     PriceConfig,
 )
 from model_providers.core.model_runtime.errors.invoke import InvokeError
-from model_providers.core.model_runtime.errors.validate import CredentialsValidateFailedError
-from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
+from model_providers.core.model_runtime.errors.validate import (
+    CredentialsValidateFailedError,
+)
+from model_providers.core.model_runtime.model_providers.__base.large_language_model import (
+    LargeLanguageModel,
+)
+from model_providers.core.model_runtime.model_providers.openai_api_compatible._common import (
+    _CommonOAI_API_Compat,
+)
 from model_providers.core.model_runtime.utils import helper
 
 logger = logging.getLogger(__name__)
@@ -46,11 +57,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
     Model class for OpenAI large language model.
     """
 
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                stream: bool = True, user: Optional[str] = None) \
-            -> Union[LLMResult, Generator]:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
 
@@ -74,11 +91,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user
+            user=user,
         )
 
-    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+    def get_num_tokens(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        tools: Optional[list[PromptMessageTool]] = None,
+    ) -> int:
         """
         Get number of tokens for given prompt messages
 
@@ -99,78 +121,80 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :return:
         """
         try:
-            headers = {
-                'Content-Type': 'application/json'
-            }
+            headers = {"Content-Type": "application/json"}
 
-            api_key = credentials.get('api_key')
+            api_key = credentials.get("api_key")
             if api_key:
                 headers["Authorization"] = f"Bearer {api_key}"
 
-            endpoint_url = credentials['endpoint_url']
-            if not endpoint_url.endswith('/'):
-                endpoint_url += '/'
+            endpoint_url = credentials["endpoint_url"]
+            if not endpoint_url.endswith("/"):
+                endpoint_url += "/"
 
             # prepare the payload for a simple ping to the model
-            data = {
-                'model': model,
-                'max_tokens': 5
-            }
+            data = {"model": model, "max_tokens": 5}
 
-            completion_type = LLMMode.value_of(credentials['mode'])
+            completion_type = LLMMode.value_of(credentials["mode"])
 
             if completion_type is LLMMode.CHAT:
-                data['messages'] = [
-                    {
-                        "role": "user",
-                        "content": "ping"
-                    },
+                data["messages"] = [
+                    {"role": "user", "content": "ping"},
                 ]
-                endpoint_url = urljoin(endpoint_url, 'chat/completions')
+                endpoint_url = urljoin(endpoint_url, "chat/completions")
             elif completion_type is LLMMode.COMPLETION:
-                data['prompt'] = 'ping'
-                endpoint_url = urljoin(endpoint_url, 'completions')
+                data["prompt"] = "ping"
+                endpoint_url = urljoin(endpoint_url, "completions")
             else:
                 raise ValueError("Unsupported completion type for model configuration.")
 
             # send a post request to validate the credentials
             response = requests.post(
-                endpoint_url,
-                headers=headers,
-                json=data,
-                timeout=(10, 60)
+                endpoint_url, headers=headers, json=data, timeout=(10, 60)
             )
 
             if response.status_code != 200:
                 raise CredentialsValidateFailedError(
-                    f'Credentials validation failed with status code {response.status_code}')
+                    f"Credentials validation failed with status code {response.status_code}"
+                )
 
             try:
                 json_result = response.json()
             except json.JSONDecodeError as e:
-                raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
+                raise CredentialsValidateFailedError(
+                    "Credentials validation failed: JSON decode error"
+                )
 
-            if (completion_type is LLMMode.CHAT
-                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+            if completion_type is LLMMode.CHAT and (
+                "object" not in json_result
+                or json_result["object"] != "chat.completion"
+            ):
                 raise CredentialsValidateFailedError(
-                    'Credentials validation failed: invalid response object, must be \'chat.completion\'')
-            elif (completion_type is LLMMode.COMPLETION
-                  and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                    "Credentials validation failed: invalid response object, must be 'chat.completion'"
+                )
+            elif completion_type is LLMMode.COMPLETION and (
+                "object" not in json_result
+                or json_result["object"] != "text_completion"
+            ):
                 raise CredentialsValidateFailedError(
-                    'Credentials validation failed: invalid response object, must be \'text_completion\'')
+                    "Credentials validation failed: invalid response object, must be 'text_completion'"
+                )
         except CredentialsValidateFailedError:
             raise
         except Exception as ex:
-            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
+            raise CredentialsValidateFailedError(
+                f"An error occurred during credentials validation: {str(ex)}"
+            )
 
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+    def get_customizable_model_schema(
+        self, model: str, credentials: dict
+    ) -> AIModelEntity:
         """
         generate custom model entities from credentials
         """
         support_function_call = False
         features = []
-        function_calling_type = credentials.get('function_calling_type', 'no_call')
-        if function_calling_type == 'function_call':
+        function_calling_type = credentials.get("function_calling_type", "no_call")
+        if function_calling_type == "function_call":
             features = [ModelFeature.TOOL_CALL]
             support_function_call = True
         endpoint_url = credentials["endpoint_url"]
@@ -185,43 +209,45 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             features=features if support_function_call else [],
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
-                ModelPropertyKey.MODE: credentials.get('mode'),
+                ModelPropertyKey.CONTEXT_SIZE: int(
+                    credentials.get("context_size", "4096")
+                ),
+                ModelPropertyKey.MODE: credentials.get("mode"),
             },
             parameter_rules=[
                 ParameterRule(
                     name=DefaultParameterName.TEMPERATURE.value,
                     label=I18nObject(en_US="Temperature"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('temperature', 0.7)),
+                    default=float(credentials.get("temperature", 0.7)),
                     min=0,
                     max=2,
-                    precision=2
+                    precision=2,
                 ),
                 ParameterRule(
                     name=DefaultParameterName.TOP_P.value,
                     label=I18nObject(en_US="Top P"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('top_p', 1)),
+                    default=float(credentials.get("top_p", 1)),
                     min=0,
                     max=1,
-                    precision=2
+                    precision=2,
                 ),
                 ParameterRule(
                     name=DefaultParameterName.FREQUENCY_PENALTY.value,
                     label=I18nObject(en_US="Frequency Penalty"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('frequency_penalty', 0)),
+                    default=float(credentials.get("frequency_penalty", 0)),
                     min=-2,
-                    max=2
+                    max=2,
                 ),
                 ParameterRule(
                     name=DefaultParameterName.PRESENCE_PENALTY.value,
                     label=I18nObject(en_US="Presence Penalty"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('presence_penalty', 0)),
+                    default=float(credentials.get("presence_penalty", 0)),
                     min=-2,
-                    max=2
+                    max=2,
                 ),
                 ParameterRule(
                     name=DefaultParameterName.MAX_TOKENS.value,
@@ -229,31 +255,40 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     type=ParameterType.INT,
                     default=512,
                     min=1,
-                    max=int(credentials.get('max_tokens_to_sample', 4096)),
-                )
+                    max=int(credentials.get("max_tokens_to_sample", 4096)),
+                ),
             ],
             pricing=PriceConfig(
-                input=Decimal(credentials.get('input_price', 0)),
-                output=Decimal(credentials.get('output_price', 0)),
-                unit=Decimal(credentials.get('unit', 0)),
-                currency=credentials.get('currency', "USD")
+                input=Decimal(credentials.get("input_price", 0)),
+                output=Decimal(credentials.get("output_price", 0)),
+                unit=Decimal(credentials.get("unit", 0)),
+                currency=credentials.get("currency", "USD"),
             ),
         )
 
-        if credentials['mode'] == 'chat':
+        if credentials["mode"] == "chat":
             entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
-        elif credentials['mode'] == 'completion':
+        elif credentials["mode"] == "completion":
             entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
         else:
-            raise ValueError(f"Unknown completion type {credentials['completion_type']}")
+            raise ValueError(
+                f"Unknown completion type {credentials['completion_type']}"
+            )
 
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
-                  stream: bool = True,
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        stream: bool = True,
+        user: Optional[str] = None,
+    ) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
 
@@ -267,50 +302,53 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         headers = {
-            'Content-Type': 'application/json',
-            'Accept-Charset': 'utf-8',
+            "Content-Type": "application/json",
+            "Accept-Charset": "utf-8",
         }
 
-        api_key = credentials.get('api_key')
+        api_key = credentials.get("api_key")
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
         endpoint_url = credentials["endpoint_url"]
-        if not endpoint_url.endswith('/'):
-            endpoint_url += '/'
+        if not endpoint_url.endswith("/"):
+            endpoint_url += "/"
 
-        data = {
-            "model": model,
-            "stream": stream,
-            **model_parameters
-        }
+        data = {"model": model, "stream": stream, **model_parameters}
 
-        completion_type = LLMMode.value_of(credentials['mode'])
+        completion_type = LLMMode.value_of(credentials["mode"])
 
         if completion_type is LLMMode.CHAT:
-            endpoint_url = urljoin(endpoint_url, 'chat/completions')
-            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
+            endpoint_url = urljoin(endpoint_url, "chat/completions")
+            data["messages"] = [
+                self._convert_prompt_message_to_dict(m) for m in prompt_messages
+            ]
         elif completion_type is LLMMode.COMPLETION:
-            endpoint_url = urljoin(endpoint_url, 'completions')
-            data['prompt'] = prompt_messages[0].content
+            endpoint_url = urljoin(endpoint_url, "completions")
+            data["prompt"] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
 
         # annotate tools with names, descriptions, etc.
-        function_calling_type = credentials.get('function_calling_type', 'no_call')
+        function_calling_type = credentials.get("function_calling_type", "no_call")
         formatted_tools = []
         if tools:
-            if function_calling_type == 'function_call':
-                data['functions'] = [{
-                    "name": tool.name,
-                    "description": tool.description,
-                    "parameters": tool.parameters
-                } for tool in tools]
-            elif function_calling_type == 'tool_call':
+            if function_calling_type == "function_call":
+                data["functions"] = [
+                    {
+                        "name": tool.name,
+                        "description": tool.description,
+                        "parameters": tool.parameters,
+                    }
+                    for tool in tools
+                ]
+            elif function_calling_type == "tool_call":
                 data["tool_choice"] = "auto"
 
                 for tool in tools:
-                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
+                    formatted_tools.append(
+                        helper.dump_model(PromptMessageFunction(function=tool))
+                    )
 
                 data["tools"] = formatted_tools
 
@@ -321,26 +359,33 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["user"] = user
 
         response = requests.post(
-            endpoint_url,
-            headers=headers,
-            json=data,
-            timeout=(10, 60),
-            stream=stream
+            endpoint_url, headers=headers, json=data, timeout=(10, 60), stream=stream
         )
 
-        if response.encoding is None or response.encoding == 'ISO-8859-1':
-            response.encoding = 'utf-8'
+        if response.encoding is None or response.encoding == "ISO-8859-1":
+            response.encoding = "utf-8"
 
         if response.status_code != 200:
-            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
+            raise InvokeError(
+                f"API request failed with status code {response.status_code}: {response.text}"
+            )
 
         if stream:
-            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_stream_response(
+                model, credentials, response, prompt_messages
+            )
 
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+        return self._handle_generate_response(
+            model, credentials, response, prompt_messages
+        )
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
-                                         prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(
+        self,
+        model: str,
+        credentials: dict,
+        response: requests.Response,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator:
         """
         Handle llm stream response
 
@@ -350,17 +395,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ''
+        full_assistant_content = ""
        chunk_index = 0
 
-        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-                -> LLMResultChunk:
+        def create_final_llm_result_chunk(
+            index: int, message: AssistantPromptMessage, finish_reason: str
+        ) -> LLMResultChunk:
             # calculate num tokens
-            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+            prompt_tokens = self._num_tokens_from_string(
+                model, prompt_messages[0].content
+            )
+            completion_tokens = self._num_tokens_from_string(
+                model, full_assistant_content
+            )
 
             # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+            usage = self._calc_response_usage(
+                model, credentials, prompt_tokens, completion_tokens
+            )
 
             return LLMResultChunk(
                 model=model,
@@ -369,21 +421,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     index=index,
                     message=message,
                     finish_reason=finish_reason,
-                    usage=usage
-                )
+                    usage=usage,
+                ),
             )
 
         # delimiter for stream response, need unicode_escape
         import codecs
 
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
                 # ignore sse comments
-                if chunk.startswith(':'):
+                if chunk.startswith(":"):
                     continue
-                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
                 chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
@@ -392,45 +445,49 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     yield create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
-                        finish_reason="Non-JSON encountered."
+                        finish_reason="Non-JSON encountered.",
                     )
                     break
-                if not chunk_json or len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json["choices"]) == 0:
                     continue
 
-                choice = chunk_json['choices'][0]
-                finish_reason = chunk_json['choices'][0].get('finish_reason')
+                choice = chunk_json["choices"][0]
+                finish_reason = chunk_json["choices"][0].get("finish_reason")
                 chunk_index += 1
 
-                if 'delta' in choice:
-                    delta = choice['delta']
-                    delta_content = delta.get('content')
-                    if delta_content is None or delta_content == '':
+                if "delta" in choice:
+                    delta = choice["delta"]
+                    delta_content = delta.get("content")
+                    if delta_content is None or delta_content == "":
                         continue
 
-                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    assistant_message_tool_calls = delta.get("tool_calls", None)
                     # assistant_message_function_call = delta.delta.function_call
 
                     # extract tool calls from response
                     if assistant_message_tool_calls:
-                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        tool_calls = self._extract_response_tool_calls(
+                            assistant_message_tool_calls
+                        )
                         # function_call = self._extract_response_function_call(assistant_message_function_call)
                         # tool_calls = [function_call] if function_call else []
 
                     # transform assistant message to prompt message
                     assistant_prompt_message = AssistantPromptMessage(
                         content=delta_content,
-                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                        tool_calls=tool_calls if assistant_message_tool_calls else [],
                     )
 
                     full_assistant_content += delta_content
-                elif 'text' in choice:
-                    choice_text = choice.get('text', '')
-                    if choice_text == '':
+                elif "text" in choice:
+                    choice_text = choice.get("text", "")
+                    if choice_text == "":
                         continue
 
                     # transform assistant message to prompt message
-                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice_text
+                    )
                     full_assistant_content += choice_text
                 else:
                     continue
@ -440,7 +497,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|||||||
yield create_final_llm_result_chunk(
|
yield create_final_llm_result_chunk(
|
||||||
index=chunk_index,
|
index=chunk_index,
|
||||||
message=assistant_prompt_message,
|
message=assistant_prompt_message,
|
||||||
finish_reason=finish_reason
|
finish_reason=finish_reason,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield LLMResultChunk(
|
yield LLMResultChunk(
|
||||||
@ -449,40 +506,50 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    ),
                )

                chunk_index += 1

    def _handle_generate_response(
        self,
        model: str,
        credentials: dict,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        response_json = response.json()

        completion_type = LLMMode.value_of(credentials["mode"])

        output = response_json["choices"][0]

        response_content = ""
        tool_calls = None
        function_calling_type = credentials.get("function_calling_type", "no_call")
        if completion_type is LLMMode.CHAT:
            response_content = output.get("message", {})["content"]
            if function_calling_type == "tool_call":
                tool_calls = output.get("message", {}).get("tool_calls")
            elif function_calling_type == "function_call":
                tool_calls = output.get("message", {}).get("function_call")

        elif completion_type is LLMMode.COMPLETION:
            response_content = output["text"]

        assistant_message = AssistantPromptMessage(
            content=response_content, tool_calls=[]
        )

        if tool_calls:
            if function_calling_type == "tool_call":
                assistant_message.tool_calls = self._extract_response_tool_calls(
                    tool_calls
                )
            elif function_calling_type == "function_call":
                assistant_message.tool_calls = [
                    self._extract_response_function_call(tool_calls)
                ]

        usage = response_json.get("usage")
        if usage:
@ -491,11 +558,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            completion_tokens = usage["completion_tokens"]
        else:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_string(
                model, prompt_messages[0].content
            )
            completion_tokens = self._num_tokens_from_string(
                model, assistant_message.content
            )

        # transform usage
        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )

        # transform response
        result = LLMResult(
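For context, a minimal sketch of the non-streaming response body that _handle_generate_response parses; only the keys read by the code above are meaningful, and all values are invented:

# Hypothetical chat-mode payload (credentials["mode"] == "chat").
chat_response_json = {
    "choices": [{"message": {"content": "It is sunny today.", "tool_calls": []}}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 8},
}

# Hypothetical completion-mode payload; without "usage", token counts fall
# back to the GPT-2 estimate in the else branch above.
completion_response_json = {"choices": [{"text": "It is sunny today."}]}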
@ -522,17 +595,19 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                message_content = cast(PromptMessageContent, message_content)
                sub_message_dict = {
                    "type": "text",
                    "text": message_content.data,
                }
                sub_messages.append(sub_message_dict)
            elif message_content.type == PromptMessageContentType.IMAGE:
                message_content = cast(
                    ImagePromptMessageContent, message_content
                )
                sub_message_dict = {
                    "type": "image_url",
                    "image_url": {
                        "url": message_content.data,
                        "detail": message_content.detail.value,
                    },
                }
                sub_messages.append(sub_message_dict)
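A hedged example of the sub_messages list produced by the conversion above for a user message that mixes text and an image; the URL and detail value are placeholders:

# Illustrative result of the multimodal content conversion above.
sub_messages = [
    {"type": "text", "text": "What is in this picture?"},
    {
        "type": "image_url",
        "image_url": {"url": "https://example.com/cat.png", "detail": "low"},
    },
]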
@ -563,7 +638,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.tool_call_id,
            }
        else:
            raise ValueError(f"Got unknown type {message}")
@ -573,8 +648,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        return message_dict

    def _num_tokens_from_string(
        self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """
        Approximate num tokens for model with gpt2 tokenizer.

@ -590,8 +666,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        return num_tokens

    def _num_tokens_from_messages(
        self,
        model: str,
        messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Approximate num tokens with GPT2 tokenizer.
        """
@ -610,10 +690,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                # which need to download the image and then get the resolution for calculation,
                # and will increase the request delay
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict) and item["type"] == "text":
                            text += item["text"]

                    value = text

@ -651,46 +731,46 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        """
        num_tokens = 0
        for tool in tools:
            num_tokens += self._get_num_tokens_by_gpt2("type")
            num_tokens += self._get_num_tokens_by_gpt2("function")
            num_tokens += self._get_num_tokens_by_gpt2("function")

            # calculate num tokens for function object
            num_tokens += self._get_num_tokens_by_gpt2("name")
            num_tokens += self._get_num_tokens_by_gpt2(tool.name)
            num_tokens += self._get_num_tokens_by_gpt2("description")
            num_tokens += self._get_num_tokens_by_gpt2(tool.description)
            parameters = tool.parameters
            num_tokens += self._get_num_tokens_by_gpt2("parameters")
            if "title" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("title")
                num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
            num_tokens += self._get_num_tokens_by_gpt2("type")
            num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
            if "properties" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("properties")
                for key, value in parameters.get("properties").items():
                    num_tokens += self._get_num_tokens_by_gpt2(key)
                    for field_key, field_value in value.items():
                        num_tokens += self._get_num_tokens_by_gpt2(field_key)
                        if field_key == "enum":
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += self._get_num_tokens_by_gpt2(enum_field)
                        else:
                            num_tokens += self._get_num_tokens_by_gpt2(field_key)
                            num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
            if "required" in parameters:
                num_tokens += self._get_num_tokens_by_gpt2("required")
                for required_field in parameters["required"]:
                    num_tokens += 3
                    num_tokens += self._get_num_tokens_by_gpt2(required_field)

        return num_tokens
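As a rough illustration of what _num_tokens_for_tools walks over, here is a hypothetical tool parameters schema that exercises the "title", "properties", "enum", and "required" branches above; every key and value ends up in a _get_num_tokens_by_gpt2 call, so the total depends on the GPT-2 tokenizer and no exact count is claimed:

# Hypothetical schema; names and values are made up for illustration.
parameters = {
    "title": "get_weather",
    "type": "object",
    "properties": {
        "city": {"type": "string", "description": "City name"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
    },
    "required": ["city"],
}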

    def _extract_response_tool_calls(
        self, response_tool_calls: list[dict]
    ) -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

@ -702,20 +782,21 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call["function"]["name"],
                    arguments=response_tool_call["function"]["arguments"],
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call["id"],
                    type=response_tool_call["type"],
                    function=function,
                )
                tool_calls.append(tool_call)

        return tool_calls
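For reference, a sketch of the response_tool_calls entries this method expects; only the key names are taken from the code above, while the id and arguments are invented:

# Illustrative input for _extract_response_tool_calls.
response_tool_calls = [
    {
        "id": "call_abc123",  # made-up id
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    }
]
# The method would return one AssistantPromptMessage.ToolCall carrying the
# same id, type, name and arguments.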

    def _extract_response_function_call(
        self, response_function_call
    ) -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response

@ -725,14 +806,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call["name"],
                arguments=response_function_call["arguments"],
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call["name"], type="function", function=function
            )

        return tool_call
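The legacy function_call variant handled above carries no id of its own, so the code reuses the function name as the ToolCall id. A minimal sketch of the expected input, with invented values:

# Illustrative legacy function_call payload.
response_function_call = {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
# The resulting ToolCall would have id="get_weather" and type="function".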